Quantifying uncertainty in prediction¶

Why study uncertainty in prediction in the first place?¶

A couple of reasons come to mind. Starting with the directions of this project: what if our model makes incorrect predictions on data outside the distribution they were trained on. This is called extrapolation and would mean that the model doesn't generalize well. But without a measure to quantify this, should we trust the model blindly?

This gives a second reason: the model itself can give this measure of uncertainty (or confidence) with it's predictions, which would make the analysis process easier.

How to quantify this?¶

One such measure is entropy. For a random variable X it would be the amount of information required on average to encode it. There is quite the theory behind this, but the short story is that when we get an expected piece of information -> we become more certain -> large probability and low entropy. But, in a 50-50 scenario type information (or uniform distribution at that), the uncertainty would be the highest -> lowest probability and high entropy. The trick in implementing this in the following cases would be to find a threshold for this uncertainty (meaning, given a new sample from the distribution inferred my the model, what would be considered outliers?).

For a classification problem, the predictive posterior is a categorical distribution over the possible classes and we would be interested in the entropy of this distribution. For a regression problem, if we think of the target distribution as a sum of more iid distributions, than we would be interested in the entropy of this, which is the entropy of a Gaussian.

The theory behind this is quite complex and big, but I tried my best with the 2 use cases below:

Initial Imports¶

In [127]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import math
import pymc as pm
import numpy as np
import pytensor.tensor as pt
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models
import pytensor.tensor as pt
import arviz as az
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
import seaborn as sns
from scipy.stats import entropy
from sklearn.metrics import mean_absolute_error, mean_squared_error
In [65]:
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

if tf.test.gpu_device_name():
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
else:
    print("Please install GPU version of TF")
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Default GPU Device: /device:GPU:0

Data Exploration¶

For the classification problem, I chose a dataset of images from 100 sports. Thus, this is an image classification problem. Without treating it as a black-box type of challenge, there are many factors that can influence these pictures according to their context, such as:

  • football images may have more green in them with a single round object
  • formula 1 racing images should be one of the few categories to have cars in images

Path Setup¶

In [66]:
classification_data_directory = "classification_data"
regression_data_directory = os.path.join("regression_data", 'housing.csv')
train_directory = os.path.join(classification_data_directory, 'train')
valid_directory = os.path.join(classification_data_directory, 'valid')
test_directory = os.path.join(classification_data_directory, 'test')

Classification data - 100 sports¶

In [67]:
sport_types_train = [type for type in os.listdir(train_directory) if os.path.isdir(os.path.join(train_directory, type))]
sport_types_valid = [type for type in os.listdir(valid_directory) if os.path.isdir(os.path.join(valid_directory, type))]
sport_types_test = [type for type in os.listdir(test_directory) if os.path.isdir(os.path.join(test_directory, type))]
sport_types = list(set(sport_types_train + sport_types_valid + sport_types_test))

sport_types
Out[67]:
['tennis',
 'balance beam',
 'hockey',
 'axe throwing',
 'archery',
 'judo',
 'hammer throw',
 'figure skating pairs',
 'bull riding',
 'speed skating',
 'pole vault',
 'horse jumping',
 'sidecar racing',
 'skydiving',
 'baton twirling',
 'rock climbing',
 'field hockey',
 'croquet',
 'sumo wrestling',
 'volleyball',
 'water polo',
 'tug of war',
 'rollerblade racing',
 'barell racing',
 'bobsled',
 'ice yachting',
 'boxing',
 'ampute football',
 'frisbee',
 'mushing',
 'pole dancing',
 'trapeze',
 'uneven bars',
 'horse racing',
 'canoe slamon',
 'roller derby',
 'lacrosse',
 'swimming',
 'harness racing',
 'jousting',
 'steer wrestling',
 'weightlifting',
 'golf',
 'wheelchair basketball',
 'bike polo',
 'basketball',
 'wheelchair racing',
 'formula 1 racing',
 'curling',
 'polo',
 'chuckwagon racing',
 'billiards',
 'rugby',
 'bungee jumping',
 'jai alai',
 'table tennis',
 'pole climbing',
 'fencing',
 'track bicycle',
 'sky surfing',
 'bmx',
 'snowmobile racing',
 'football',
 'ski jumping',
 'nascar racing',
 'figure skating men',
 'snow boarding',
 'arm wrestling',
 'gaga',
 'luge',
 'shuffleboard',
 'olympic wrestling',
 'horseshoe pitching',
 'log rolling',
 'bowling',
 'javelin',
 'giant slalom',
 'sailboat racing',
 'cricket',
 'disc golf',
 'surfing',
 'rowing',
 'rings',
 'air hockey',
 'parallel bar',
 'baseball',
 'high jump',
 'hang gliding',
 'hydroplane racing',
 'ultimate',
 'pommel horse',
 'ice climbing',
 'motorcycle racing',
 'cheerleading',
 'wingsuit flying',
 'figure skating women',
 'water cycling',
 'fly fishing',
 'shot put',
 'hurdles']

Number of images per label¶

Pretty much shows the raw set was inbalanced

In [68]:
def count_images_in_folder(folder_path):
    return len([f for f in os.listdir(folder_path) if f.endswith('.jpg')])

data = []
for type in sport_types:
    train_count = count_images_in_folder(os.path.join(train_directory, type))
    valid_count = count_images_in_folder(os.path.join(valid_directory, type))
    test_count = count_images_in_folder(os.path.join(test_directory, type))
    data.append({'Type': type, 'Train': train_count, 'Test': test_count, 'Valid': valid_count})

sport_types_statistics = pd.DataFrame(data)
sport_types_statistics
Out[68]:
Type Train Test Valid
0 tennis 131 5 5
1 balance beam 147 5 5
2 hockey 172 5 5
3 axe throwing 113 5 5
4 archery 132 5 5
... ... ... ... ...
95 figure skating women 157 5 5
96 water cycling 103 5 5
97 fly fishing 134 5 5
98 shot put 149 5 5
99 hurdles 136 5 5

100 rows × 4 columns

In [69]:
sport_types_statistics.set_index('Type')[['Train', 'Valid', 'Test']].plot(kind='bar', figsize=(50, 30))
plt.title('Number of Images per Type in Train, Valid and Test Set')
plt.ylabel('Number of Images')
plt.xlabel('Sport Type')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
No description has been provided for this image

Common image size¶

In [70]:
def get_image_size(image_path):
    with Image.open(image_path) as img:
        return img.size  

image_sizes = []
for folder in [train_directory, valid_directory, test_directory]:
    for type in os.listdir(folder):
        sport_type_dir = os.path.join(folder, type)
        if os.path.isdir(sport_type_dir):
            for image_file in os.listdir(sport_type_dir):
                if image_file.endswith('.jpg'):
                    image_path = os.path.join(sport_type_dir, image_file)
                    image_sizes.append(get_image_size(image_path))

sizes_dataframe = pd.DataFrame(image_sizes, columns=['Width', 'Height'])
sizes_dataframe['Width'].unique(), sizes_dataframe['Height'].unique()
Out[70]:
(array([224], dtype=int64), array([224], dtype=int64))
In [71]:
IMAGE_SIZE = (224, 224)
IMAGE_SHAPE = (224, 224, 3)

Visualize samples from each label¶

In [72]:
num_classes = len(sport_types_train)
grid_cols = math.ceil(math.sqrt(num_classes))  
grid_rows = math.ceil(num_classes / grid_cols)  

fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 15))
axes = axes.flatten()  

for i, sport_type in enumerate(sport_types_train):
    sport_type_dir = os.path.join(train_directory, sport_type)
    image_files = [f for f in os.listdir(sport_type_dir) if f.endswith('.jpg')]
    
    if image_files:  
        image_path = os.path.join(sport_type_dir, image_files[0])
        with Image.open(image_path) as img:
            axes[i].imshow(np.array(img.resize(IMAGE_SIZE))) 
            axes[i].set_title(sport_type)
            axes[i].axis('off')
    else:
        axes[i].axis('off')  

for i in range(len(sport_types_train), len(axes)):
    axes[i].axis('off')

plt.tight_layout()
plt.show()
No description has been provided for this image

Preparing the data¶

Classification data (choose train classes + unseen class)¶

I chose 10 classes for training, with the intention of taking ~ 100 samples each -> 1000 training samples overall. The labels are represented by the folder in which the sample image is, so we should have the first 10 alphabetical labels:

  • 'air hockey',
  • 'ampute football',
  • 'archery',
  • 'arm wrestling',
  • 'axe throwing',
  • 'balance beam',
  • 'barell racing',
  • 'baseball',
  • 'basketball',
  • 'baton twirling'

And the unseen class should be: bike polo

In [73]:
CLASS_COUNT = 10
train_classes = sport_types_train[:CLASS_COUNT]
train_classes
Out[73]:
['air hockey',
 'ampute football',
 'archery',
 'arm wrestling',
 'axe throwing',
 'balance beam',
 'barell racing',
 'baseball',
 'basketball',
 'baton twirling']
In [74]:
unseen_class = sport_types_train[CLASS_COUNT]
unseen_class
Out[74]:
'bike polo'
In [75]:
class_to_label = {sport: idx for idx, sport in enumerate(train_classes)}
class_to_label[unseen_class] = CLASS_COUNT
class_to_label
Out[75]:
{'air hockey': 0,
 'ampute football': 1,
 'archery': 2,
 'arm wrestling': 3,
 'axe throwing': 4,
 'balance beam': 5,
 'barell racing': 6,
 'baseball': 7,
 'basketball': 8,
 'baton twirling': 9,
 'bike polo': 10}

Subset the data (earlier attempts took too long with all data)¶

I have 2 options here:

  • use a fraction of the given data (attempted this because previous chain sampling took very long locally)
  • use the whole data within a given limit (here I set 100 as max number of images per label, because this is the minimum between the labels - this is an unbalanced set)
In [76]:
def prepare_data(classes, data_dir):
    data = []
    labels = []
    for sport_type in classes:
        sport_type_dir = os.path.join(data_dir, sport_type)
        image_files = [os.path.join(sport_type_dir, f) for f in os.listdir(sport_type_dir) if f.endswith('.jpg')]
        
        # Assign the correct label using the mapping
        label = class_to_label[sport_type]
        
        data.extend(image_files)
        labels.extend([label] * len(image_files))
    
    return data, labels
In [77]:
def subset_data(data, labels, fraction=0.1):
    np.random.seed(42)
    indices = np.random.choice(len(data), size=int(len(data) * fraction), replace=False)
    return np.array(data)[indices], np.array(labels)[indices]
In [78]:
SUBSET_FRACTION = 0.15
TRAIN_IMAGES_RAW_LIMIT = 99
train_data = []
train_labels =[]
valid_data = []
valid_labels = []
test_data = []
test_labels = []
train_subset_data = []
train_subset_labels = []
valid_subset_data = []
valid_subset_labels = []
test_subset_data = []
test_subset_labels = []

for train_class in train_classes:
    train_class_data, train_class_labels = prepare_data([train_class], train_directory)
    train_class_data = train_class_data[:TRAIN_IMAGES_RAW_LIMIT]
    train_class_labels = train_class_labels[:TRAIN_IMAGES_RAW_LIMIT]
    train_class_subset_data, train_class_subset_labels = subset_data(train_class_data, train_class_labels, fraction=SUBSET_FRACTION)
    train_data.extend(train_class_data)
    train_labels.extend(train_class_labels)
    train_subset_data.extend(train_class_subset_data)
    train_subset_labels.extend(train_class_subset_labels)
    
    valid_class_data, valid_class_labels = prepare_data([train_class], valid_directory)
    valid_class_subset_data, valid_class_subset_labels = subset_data(valid_class_data, valid_class_labels, fraction=2*SUBSET_FRACTION)
    valid_data.extend(valid_class_data)
    valid_labels.extend(valid_class_labels)
    valid_subset_data.extend(valid_class_subset_data)
    valid_subset_labels.extend(valid_class_subset_labels)

    test_class_data, test_class_labels = prepare_data([train_class], test_directory)
    test_class_subset_data, test_class_subset_labels = subset_data(test_class_data, test_class_labels, fraction=2*SUBSET_FRACTION)
    test_data.extend(test_class_data)
    test_labels.extend(test_class_labels)
    test_subset_data.extend(test_class_subset_data)
    test_subset_labels.extend(test_class_subset_labels)
   
combined_test_data = np.concatenate([test_data, valid_data])
combined_test_labels = np.concatenate([test_labels, valid_labels])    
combined_test_subset_data = np.concatenate([test_subset_data, valid_subset_data])
combined_test_subset_labels = np.concatenate([test_subset_labels, valid_subset_labels])

print("Training Data Shape:", len(train_data))
print("Training Subset Data Shape:", len(train_subset_data))
print("Testing Data Shape:", len(combined_test_data))
print("Testing Subset Data Shape:", len(combined_test_subset_data))
Training Data Shape: 990
Training Subset Data Shape: 140
Testing Data Shape: 100
Testing Subset Data Shape: 20
In [79]:
unseen_train_data, unseen_train_labels = prepare_data([unseen_class], train_directory)
unseen_train_data = unseen_train_data[:TRAIN_IMAGES_RAW_LIMIT]
unseen_train_labels = unseen_train_labels[:TRAIN_IMAGES_RAW_LIMIT]
unseen_train_subset_data, unseen_train_subset_labels = subset_data(unseen_train_data, unseen_train_labels, fraction=SUBSET_FRACTION)

unseen_valid_data, unseen_valid_labels = prepare_data([unseen_class], valid_directory)
unseen_valid_subset_data, unseen_valid_subset_labels = subset_data(unseen_valid_data, unseen_valid_labels, fraction=2*SUBSET_FRACTION)
unseen_test_data, unseen_test_labels = prepare_data([unseen_class], test_directory)
unseen_test_subset_data, unseen_test_subset_labels = subset_data(unseen_test_data, unseen_test_labels, fraction=2*SUBSET_FRACTION)

combined_unseen_test_data = np.concatenate([unseen_test_data, unseen_valid_data])
combined_unseen_test_subset_data = np.concatenate([unseen_test_subset_data, unseen_valid_subset_data])
combined_unseen_test_labels = np.concatenate([unseen_test_labels, unseen_valid_labels])
combined_unseen_test_subset_labels = np.concatenate([unseen_test_subset_labels, unseen_valid_subset_labels])

print("Unseen Training Data Shape:", len(unseen_train_data))
print("Unseen Training Subset Data Shape:", len(unseen_train_subset_data))
print("Unseen Testing Data Shape:", len(combined_unseen_test_data))
print("Unseen Testing Subset Data Shape:", len(combined_unseen_test_subset_data))
Unseen Training Data Shape: 99
Unseen Training Subset Data Shape: 14
Unseen Testing Data Shape: 10
Unseen Testing Subset Data Shape: 2

Feature extraction using ResNet50 + PCA¶

Used pretrained model to extract vectorized embeddings. Because initial vector size was 2048, it made the whole process a lot slower (not only that, but because of local limitations, features had to be extracted in batches).

Because of this, I decided to use PCA, which is a compression method, to keep the most relevant features.

In [80]:
feature_extractor = tf.keras.applications.ResNet50(
    include_top=False, input_shape=IMAGE_SHAPE, pooling='avg', weights='imagenet'
)
In [81]:
def extract_features(image_paths):
    images = [
        tf.image.resize(tf.image.decode_jpeg(tf.io.read_file(p)), IMAGE_SIZE) 
        for p in image_paths
    ]
    images = tf.stack(images) / 255.0 
    return feature_extractor(images).numpy()
In [82]:
def extract_features_in_batches(image_paths, batch_size=32):
    features = []
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = [
            tf.image.resize(tf.image.decode_jpeg(tf.io.read_file(p)), IMAGE_SIZE) 
            for p in batch_paths
        ]
        batch_images = tf.stack(batch_images) / 255.0 
        batch_features = feature_extractor(batch_images).numpy()
        features.append(batch_features)
    return np.concatenate(features, axis=0)
In [83]:
train_features = extract_features_in_batches(train_data) 
test_features = extract_features_in_batches(combined_test_data)
unseen_train_features = extract_features_in_batches(unseen_train_data)
unseen_test_features = extract_features_in_batches(combined_unseen_test_data)
train_features.shape, test_features.shape, unseen_train_features.shape, unseen_test_features.shape
Out[83]:
((990, 2048), (100, 2048), (99, 2048), (10, 2048))
In [84]:
pca = PCA(n_components=99)
In [85]:
train_features_pca = pca.fit_transform(train_features)
test_features_pca = pca.transform(test_features)
unseen_train_features_pca = pca.fit_transform(unseen_train_features)
unseen_test_features_pca = pca.transform(unseen_test_features)
train_features_pca.shape, test_features_pca.shape, unseen_train_features_pca.shape, unseen_test_features_pca.shape
Out[85]:
((990, 99), (100, 99), (99, 99), (10, 99))

Bayesian Neural Network¶

Setting the architecture and the utility functions¶

There were some good pieces of documentation / info out there, such as:

  • PYMC Officials Docs on Bayesian Neural Networks
  • Stackoverflow problem on IRIS dataset
  • PYMC YT video for Biotech (this was more for me to get used to some structures)
In [86]:
def relu(x):
  return pt.switch(pt.lt(x, 0), 0, x)

def softmax(x):
    exp_x = pt.exp(x - pt.max(x, axis=1, keepdims=True)) 
    return exp_x / pt.sum(exp_x, axis=1, keepdims=True)

def multiclass_bayesian_network(input_data, output_data, hidden_variables_count=128, num_classes_param=None):
    if num_classes is None:
        raise ValueError("num_classes must be specified")

    with pm.Model() as bnn_model:
        features_data = pm.Data("input_data", input_data)
        labels_data  = pm.Data("output_data", output_data)

        weights_1 = pm.Normal("W1", mu=0, sigma=0.5, shape=(features_data.shape[1], hidden_variables_count))
        bias_1 = pm.Normal("b1", mu=0, sigma=0.5, shape=(hidden_variables_count,))
        hidden = relu(pt.dot(features_data, weights_1) + bias_1)

        weights_2 = pm.Normal("W2", mu=0, sigma=0.5, shape=(hidden_variables_count, num_classes_param))
        bias_2 = pm.Normal("b2", mu=0, sigma=0.5, shape=(num_classes_param,))
        output_logits = pt.dot(hidden, weights_2) + bias_2
        output_probs = pm.Deterministic("output_probs", softmax(output_logits))

        y_obs = pm.Categorical("y_obs", p=output_probs, observed=labels_data)
    
    return bnn_model
In [113]:
multiclass_bayesian_model = multiclass_bayesian_network(train_features_pca, train_labels, hidden_variables_count=192, num_classes_param=CLASS_COUNT+1)
In [114]:
with multiclass_bayesian_model:
    approx = pm.fit(n = 50000)

bayesian_model_trace = approx.sample(draws = 5000)
Output()

Finished [100%]: Average Loss = 2,964.4
In [115]:
plt.plot(approx.hist)
plt.title("Loss History during ADVI")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
No description has been provided for this image

Generating predictions on posterior and evaluating it on the training set¶

That would be the whole purpose and by comparing the predictions, we get a better idea of the uncertainties (it makes more sens on unseen/testing data) This section's evaluation will be on the training set

In [139]:
with multiclass_bayesian_model:
    posterior_predictions = pm.sample_posterior_predictive(bayesian_model_trace, var_names=["output_probs"])

posterior_probabilities = posterior_predictions.posterior_predictive["output_probs"]
Sampling: [W1]
Output()

In [117]:
posterior_probabilities_combined = posterior_probabilities.values.reshape(
    -1, 
    posterior_probabilities.shape[2],
    posterior_probabilities.shape[3]  
)
posterior_probabilities_combined.shape
Out[117]:
(5000, 990, 11)
In [118]:
mean_probabilities = posterior_probabilities_combined.mean(axis=0)
mean_probabilities.shape
Out[118]:
(990, 11)
In [119]:
predicted_classes = np.argmax(mean_probabilities, axis=1)
predicted_classes.shape
Out[119]:
(990,)
In [120]:
predicted_classes = np.argmax(mean_probabilities, axis=1)
accuracy = accuracy_score(train_labels, predicted_classes)
print(f"BNN Training Set Accuracy: {accuracy * 100:.2f}%")
BNN Training Set Accuracy: 60.81%

Visualize the weights posterior distributions¶

In [121]:
weights_1_samples = bayesian_model_trace.posterior["W1"].values
weights_2_samples = bayesian_model_trace.posterior["W2"].values

weights_1_flat = weights_1_samples.reshape(-1, weights_1_samples.shape[-1])
weights_2_flat = weights_2_samples.reshape(-1, weights_2_samples.shape[-1])

plt.figure(figsize=(12, 6))
plt.hist(weights_1_flat[:, 0], bins=30, density=True, alpha=0.7, label="W1[0]")
plt.hist(weights_1_flat[:, 1], bins=30, density=True, alpha=0.7, label="W1[1]")
plt.title("Posterior Distribution of W1")
plt.legend()
plt.show()

plt.figure(figsize=(12, 6))
plt.hist(weights_2_flat[:, 0], bins=30, density=True, alpha=0.7, label="W2[0]")
plt.hist(weights_2_flat[:, 1], bins=30, density=True, alpha=0.7, label="W2[1]")
plt.title("Posterior Distribution of W2")
plt.legend()
plt.show()
No description has been provided for this image
No description has been provided for this image

Evaluate on the testing set¶

In [122]:
def entropy(probabilities):
    return -np.sum(probabilities * np.log(probabilities + 1e-10), axis=1)
In [171]:
with multiclass_bayesian_model:
    pm.set_data({"input_data": test_features_pca, "output_data": combined_test_labels})
    posterior_pred_test = pm.sample_posterior_predictive(bayesian_model_trace, var_names=["output_probs"])
Sampling: [W1]
Output()

In [172]:
test_mean_probs = np.mean(np.squeeze(posterior_pred_test.posterior_predictive["output_probs"]), axis=0) 

test_epistemic_uncertainty = np.var(np.squeeze(posterior_pred_test.posterior_predictive["output_probs"]), axis=0) 
test_aleatoric_uncertainty = entropy(test_mean_probs.T)

test_max_epistemic_uncertainty = test_epistemic_uncertainty.max(axis=0) 
test_novel_classes = test_epistemic_uncertainty.max(axis=1) > 0.1

print("Novel classes detected:", test_novel_classes)
print("Mean probabilities:", test_mean_probs[0])
print("Epistemic uncertainty:", test_epistemic_uncertainty[0])
print("Aleatoric uncertainty:", test_aleatoric_uncertainty)
Novel classes detected: <xarray.DataArray 'output_probs' (output_probs_dim_2: 100)> Size: 100B
array([ True,  True, False, False,  True,  True, False, False,  True,
       False,  True, False, False, False, False, False, False, False,
       False, False, False, False, False,  True, False, False, False,
       False, False,  True, False, False,  True, False, False, False,
        True, False, False, False, False,  True, False,  True,  True,
        True,  True, False,  True, False,  True,  True, False,  True,
       False, False, False,  True, False, False,  True, False, False,
       False, False,  True, False,  True,  True, False, False, False,
       False, False, False, False, False,  True, False,  True, False,
       False, False, False, False, False,  True, False, False, False,
        True,  True,  True, False,  True, False, False, False, False,
       False])
Coordinates:
    chain               int32 4B 0
  * output_probs_dim_2  (output_probs_dim_2) int32 400B 0 1 2 3 ... 96 97 98 99
Mean probabilities: <xarray.DataArray 'output_probs' (output_probs_dim_3: 11)> Size: 88B
array([0.12047857, 0.03830217, 0.08886912, 0.098767  , 0.27824841,
       0.10374331, 0.03321068, 0.06162477, 0.08236284, 0.08791992,
       0.0064732 ])
Coordinates:
    chain               int32 4B 0
    output_probs_dim_2  int32 4B 0
  * output_probs_dim_3  (output_probs_dim_3) int32 44B 0 1 2 3 4 5 6 7 8 9 10
Epistemic uncertainty: <xarray.DataArray 'output_probs' (output_probs_dim_3: 11)> Size: 88B
array([0.08063358, 0.0250988 , 0.05855689, 0.06815269, 0.15466592,
       0.07055334, 0.02287705, 0.04123484, 0.05610241, 0.0581371 ,
       0.00429029])
Coordinates:
    chain               int32 4B 0
    output_probs_dim_2  int32 4B 0
  * output_probs_dim_3  (output_probs_dim_3) int32 44B 0 1 2 3 4 5 6 7 8 9 10
Aleatoric uncertainty: [2.15151485 1.98659122 2.22919085 2.28322409 2.21591882 2.13388424
 2.278722   2.27254052 2.14835955 2.27965069 2.19427503 2.28711445
 2.28170678 2.27677932 2.28533225 2.24011298 2.27281038 2.25423949
 2.28084961 2.27865669 2.26559945 2.28341933 2.28511396 2.24595713
 2.22066714 2.28386501 2.27431511 2.25143691 2.28624264 2.08353971
 2.2642976  2.27597631 2.24453399 2.27195307 2.25222716 2.27234776
 2.24218153 2.28227424 2.27347661 2.27864978 2.27934771 2.13495062
 2.2802027  2.16445266 2.21796425 2.20132708 2.24894247 2.27944628
 2.23819392 2.28021848 2.19692635 2.23800513 2.27563449 2.06666598
 2.25594406 2.27809693 2.26270891 2.15740496 2.27463285 2.21033119
 2.20741521 2.20712788 2.27846531 2.2769733  2.27193023 2.24362428
 2.2761692  2.22579178 2.23604134 2.26065427 2.28242154 2.25256941
 2.2852577  2.23758645 2.28425505 2.28228572 2.28051113 2.20228636
 2.26045148 2.00690108 2.20345334 2.26392255 2.25764582 2.17487155
 2.21291318 2.28168361 2.23947088 2.25990939 2.27341316 2.27051518
 2.22023712 2.15661177 2.0724298  2.24527942 2.22802674 2.27827164
 2.25267702 2.27937073 2.27881174 2.21320226]
In [173]:
test_probs = posterior_pred_test.posterior_predictive["output_probs"].mean(axis=(0, 1)).values
test_uncertainty = posterior_pred_test.posterior_predictive["output_probs"].std(axis=(0, 1)).values

test_predictions = np.argmax(test_probs, axis=1)
accuracy = accuracy_score(combined_test_labels, test_predictions)
print(f"BNN Testing Set Accuracy: {accuracy * 100:.2f}%")
BNN Testing Set Accuracy: 9.00%
In [124]:
test_entropies = entropy(test_probs)
entropy_threshold = test_entropies.mean() + 2 * test_entropies.std()

plt.figure(figsize=(10, 6))
plt.hist(test_entropies, bins=30, alpha=0.7, label="Known Classes Entropy")
plt.axvline(entropy_threshold, color="red", linestyle="--", label="Entropy Threshold")
plt.legend()
plt.title("Entropy Distribution")
plt.xlabel("Entropy")
plt.ylabel("Frequency")
plt.show()
No description has been provided for this image

Then evaluate it for the unseen class set¶

The key is to use the uncertainty in the model's predictions to identify whether an example belongs to a class that the model has seen during training. Typically, a "high uncertainty" prediction means the model is not confident, which could indicate that the example is from an unseen class.

Entropy: One approach is to measure the uncertainty by computing the entropy of the predicted class probabilities. If the entropy is high, it suggests the model is uncertain, which could indicate an unseen class.

In [174]:
with multiclass_bayesian_model:
    pm.set_data({"input_data": unseen_train_features_pca, "output_data": unseen_train_labels})
    posterior_pred_unseen  = pm.sample_posterior_predictive(bayesian_model_trace, var_names=["output_probs"])
Sampling: [W1]
Output()

In [175]:
unseen_mean_probs = np.mean(np.squeeze(posterior_pred_unseen.posterior_predictive["output_probs"]), axis=0) 

unseen_epistemic_uncertainty = np.var(np.squeeze(posterior_pred_unseen.posterior_predictive["output_probs"]), axis=0) 
unseen_aleatoric_uncertainty = entropy(test_mean_probs.T)

unseen_max_epistemic_uncertainty = unseen_epistemic_uncertainty.max(axis=0) 
unseen_novel_classes = unseen_epistemic_uncertainty.max(axis=1) > 0.1

print("Novel classes detected:", unseen_novel_classes)
print("Mean probabilities:", unseen_mean_probs[0])
print("Epistemic uncertainty:", unseen_epistemic_uncertainty[0])
print("Aleatoric uncertainty:", unseen_aleatoric_uncertainty)
Novel classes detected: <xarray.DataArray 'output_probs' (output_probs_dim_2: 99)> Size: 99B
array([ True, False, False,  True, False,  True, False, False, False,
        True,  True, False,  True,  True, False, False, False, False,
       False,  True,  True, False, False, False,  True,  True,  True,
        True, False, False, False,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True, False,  True,
        True,  True, False,  True,  True,  True, False,  True,  True,
       False, False, False, False, False, False,  True, False, False,
       False, False,  True,  True,  True, False, False, False,  True,
        True, False, False, False,  True,  True,  True, False, False,
        True,  True,  True, False, False,  True,  True, False, False,
       False,  True, False,  True, False,  True, False,  True, False])
Coordinates:
    chain               int32 4B 0
  * output_probs_dim_2  (output_probs_dim_2) int32 396B 0 1 2 3 ... 95 96 97 98
Mean probabilities: <xarray.DataArray 'output_probs' (output_probs_dim_3: 11)> Size: 88B
array([0.12480786, 0.05349336, 0.10452001, 0.08237487, 0.22778988,
       0.0880089 , 0.04863846, 0.07709501, 0.09191026, 0.09594223,
       0.00541915])
Coordinates:
    chain               int32 4B 0
    output_probs_dim_2  int32 4B 0
  * output_probs_dim_3  (output_probs_dim_3) int32 44B 0 1 2 3 4 5 6 7 8 9 10
Epistemic uncertainty: <xarray.DataArray 'output_probs' (output_probs_dim_3: 11)> Size: 88B
array([0.07005756, 0.02692948, 0.05646839, 0.04880137, 0.1138918 ,
       0.05078056, 0.02589244, 0.0393869 , 0.05180485, 0.05265538,
       0.00243326])
Coordinates:
    chain               int32 4B 0
    output_probs_dim_2  int32 4B 0
  * output_probs_dim_3  (output_probs_dim_3) int32 44B 0 1 2 3 4 5 6 7 8 9 10
Aleatoric uncertainty: [2.15151485 1.98659122 2.22919085 2.28322409 2.21591882 2.13388424
 2.278722   2.27254052 2.14835955 2.27965069 2.19427503 2.28711445
 2.28170678 2.27677932 2.28533225 2.24011298 2.27281038 2.25423949
 2.28084961 2.27865669 2.26559945 2.28341933 2.28511396 2.24595713
 2.22066714 2.28386501 2.27431511 2.25143691 2.28624264 2.08353971
 2.2642976  2.27597631 2.24453399 2.27195307 2.25222716 2.27234776
 2.24218153 2.28227424 2.27347661 2.27864978 2.27934771 2.13495062
 2.2802027  2.16445266 2.21796425 2.20132708 2.24894247 2.27944628
 2.23819392 2.28021848 2.19692635 2.23800513 2.27563449 2.06666598
 2.25594406 2.27809693 2.26270891 2.15740496 2.27463285 2.21033119
 2.20741521 2.20712788 2.27846531 2.2769733  2.27193023 2.24362428
 2.2761692  2.22579178 2.23604134 2.26065427 2.28242154 2.25256941
 2.2852577  2.23758645 2.28425505 2.28228572 2.28051113 2.20228636
 2.26045148 2.00690108 2.20345334 2.26392255 2.25764582 2.17487155
 2.21291318 2.28168361 2.23947088 2.25990939 2.27341316 2.27051518
 2.22023712 2.15661177 2.0724298  2.24527942 2.22802674 2.27827164
 2.25267702 2.27937073 2.27881174 2.21320226]
In [ ]:
unseen_probs  = posterior_pred_test.posterior_predictive["output_probs"].mean(axis=(0, 1)).values
unseen_uncertainty  = posterior_pred_test.posterior_predictive["output_probs"].std(axis=(0, 1)).values

unseen_predictions = np.argmax(unseen_probs, axis=1)
accuracy = accuracy_score(unseen_train_labels, unseen_predictions[:TRAIN_IMAGES_RAW_LIMIT])
print(f"BNN Testing Set Accuracy: {accuracy * 100:.2f}%")
In [126]:
unseen_entropies = entropy(unseen_probs)
unseen_entropy_threshold = unseen_entropies.mean() + 2 * unseen_entropies.std()

plt.figure(figsize=(10, 6))
plt.hist(unseen_entropies, bins=30, alpha=0.7, label="Unseen Class Entropy")
plt.axvline(unseen_entropy_threshold, color="red", linestyle="--", label="Entropy Threshold")
plt.legend()
plt.title("Entropy Distribution")
plt.xlabel("Entropy")
plt.ylabel("Frequency")
plt.show()
No description has been provided for this image

Using classical CNNs¶

Setting the architecture¶

Will use a convolutional neural network for this multiclass image classification problem -> done using tensorflow

In [44]:
def build_cnn_model(input_shape, num_classes_param):
    model = models.Sequential()
    model.add(layers.InputLayer(input_shape=input_shape))
    model.add(layers.Reshape((input_shape[0], 1))) 
    model.add(layers.Conv1D(64, 3, activation='relu'))
    model.add(layers.MaxPooling1D(2))
    model.add(layers.Conv1D(128, 3, activation='relu'))
    model.add(layers.MaxPooling1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(num_classes_param, activation='softmax'))  
    return model

Building the CNN model and train it¶

The same data will be used for a fair comparison

In [45]:
multiclass_cnn_model = build_cnn_model((train_features_pca.shape[1],), CLASS_COUNT)

multiclass_cnn_model.compile(optimizer='adam', 
                  loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])

multiclass_cnn_model_history = multiclass_cnn_model.fit(
    train_features_pca, 
    np.array(train_labels), 
    epochs=20, 
    batch_size=32, 
    validation_data=(test_features_pca, np.array(combined_test_labels)))
Epoch 1/20
31/31 [==============================] - 2s 19ms/step - loss: 2.2813 - accuracy: 0.1657 - val_loss: 2.1814 - val_accuracy: 0.2600
Epoch 2/20
31/31 [==============================] - 0s 6ms/step - loss: 2.1372 - accuracy: 0.2313 - val_loss: 1.8984 - val_accuracy: 0.4000
Epoch 3/20
31/31 [==============================] - 0s 6ms/step - loss: 1.9293 - accuracy: 0.3384 - val_loss: 1.6939 - val_accuracy: 0.4200
Epoch 4/20
31/31 [==============================] - 0s 7ms/step - loss: 1.7967 - accuracy: 0.3778 - val_loss: 1.6403 - val_accuracy: 0.4400
Epoch 5/20
31/31 [==============================] - 0s 6ms/step - loss: 1.6455 - accuracy: 0.4475 - val_loss: 1.5402 - val_accuracy: 0.5400
Epoch 6/20
31/31 [==============================] - 0s 6ms/step - loss: 1.5589 - accuracy: 0.4636 - val_loss: 1.4450 - val_accuracy: 0.5400
Epoch 7/20
31/31 [==============================] - 0s 6ms/step - loss: 1.4431 - accuracy: 0.5141 - val_loss: 1.3927 - val_accuracy: 0.5700
Epoch 8/20
31/31 [==============================] - 0s 5ms/step - loss: 1.3231 - accuracy: 0.5586 - val_loss: 1.3693 - val_accuracy: 0.5800
Epoch 9/20
31/31 [==============================] - 0s 5ms/step - loss: 1.2217 - accuracy: 0.6081 - val_loss: 1.3443 - val_accuracy: 0.5900
Epoch 10/20
31/31 [==============================] - 0s 6ms/step - loss: 1.0852 - accuracy: 0.6495 - val_loss: 1.3437 - val_accuracy: 0.6200
Epoch 11/20
31/31 [==============================] - 0s 7ms/step - loss: 1.0130 - accuracy: 0.6697 - val_loss: 1.3040 - val_accuracy: 0.5800
Epoch 12/20
31/31 [==============================] - 0s 6ms/step - loss: 0.9402 - accuracy: 0.6778 - val_loss: 1.2952 - val_accuracy: 0.5600
Epoch 13/20
31/31 [==============================] - 0s 7ms/step - loss: 0.8554 - accuracy: 0.7101 - val_loss: 1.2352 - val_accuracy: 0.6000
Epoch 14/20
31/31 [==============================] - 0s 7ms/step - loss: 0.7196 - accuracy: 0.7636 - val_loss: 1.2176 - val_accuracy: 0.6300
Epoch 15/20
31/31 [==============================] - 0s 7ms/step - loss: 0.6575 - accuracy: 0.7828 - val_loss: 1.2097 - val_accuracy: 0.5900
Epoch 16/20
31/31 [==============================] - 0s 6ms/step - loss: 0.6042 - accuracy: 0.8111 - val_loss: 1.2197 - val_accuracy: 0.6600
Epoch 17/20
31/31 [==============================] - 0s 7ms/step - loss: 0.5317 - accuracy: 0.8364 - val_loss: 1.3480 - val_accuracy: 0.6300
Epoch 18/20
31/31 [==============================] - 0s 7ms/step - loss: 0.4769 - accuracy: 0.8535 - val_loss: 1.3447 - val_accuracy: 0.6100
Epoch 19/20
31/31 [==============================] - 0s 7ms/step - loss: 0.4356 - accuracy: 0.8657 - val_loss: 1.3613 - val_accuracy: 0.6100
Epoch 20/20
31/31 [==============================] - 0s 6ms/step - loss: 0.3632 - accuracy: 0.8949 - val_loss: 1.3563 - val_accuracy: 0.6300

Plotting the training curves¶

In [46]:
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(multiclass_cnn_model_history.history['accuracy'], label='Training Accuracy')
plt.plot(multiclass_cnn_model_history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(multiclass_cnn_model_history.history['loss'], label='Training Loss')
plt.plot(multiclass_cnn_model_history.history['val_loss'], label='Validation Loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()
No description has been provided for this image

Predicting on the unseen data using CNN¶

Just like before, this is expected to give answers from existing training labels, but we have to measure this uncertainty

In [47]:
def predict_unseen_class_with_cnn(model, data, threshold=0.6):
    cnn_predictions = model.predict(data)
    cnn_max_probs = np.max(cnn_predictions, axis=1)
    
    cnn_predicted_classes = np.argmax(cnn_predictions, axis=1)
    cnn_predicted_labels = [cnn_predicted_classes[i] if cnn_max_probs[i] > threshold else CLASS_COUNT 
                        for i in range(len(cnn_predictions))]
    
    return cnn_predicted_labels, cnn_max_probs
In [48]:
cnn_predicted_labels,cnn_max_probabilities = predict_unseen_class_with_cnn(multiclass_cnn_model, unseen_train_features_pca)

print("Predicted Labels for Unseen Train Data: ", cnn_predicted_labels)
print("Max Probabilities for Each Sample: ", cnn_max_probabilities)

cnn_unseen_accuracy = np.sum(np.array(cnn_predicted_labels) == np.array(unseen_train_labels)) / len(cnn_predicted_labels)
print(f"Unseen Test Accuracy (CNN model): {cnn_unseen_accuracy * 100:.2f}%")
4/4 [==============================] - 0s 48ms/step
Predicted Labels for Unseen Train Data:  [3, 10, 10, 10, 10, 0, 10, 2, 10, 9, 8, 10, 7, 10, 3, 10, 3, 10, 0, 10, 10, 10, 1, 3, 10, 0, 8, 10, 0, 10, 5, 10, 4, 8, 3, 0, 0, 8, 0, 10, 10, 7, 7, 10, 3, 7, 10, 0, 5, 7, 7, 3, 10, 8, 5, 10, 10, 10, 5, 3, 10, 0, 10, 10, 3, 9, 10, 9, 8, 3, 10, 7, 10, 4, 6, 10, 2, 6, 5, 10, 3, 5, 10, 8, 8, 10, 10, 7, 10, 10, 0, 6, 3, 3, 10, 9, 10, 0, 6]
Max Probabilities for Each Sample:  [0.9064329  0.45054948 0.5946742  0.5511862  0.3530587  0.87295985
 0.51526433 0.8329291  0.45422179 0.97006536 0.7646488  0.38729084
 0.84711635 0.4400103  0.9416331  0.30798355 0.795604   0.28646448
 0.91080946 0.5087042  0.52449185 0.45686173 0.88511413 0.9911891
 0.54418766 0.9676108  0.998692   0.34801716 0.92508197 0.3443876
 0.7578489  0.34213483 0.63492376 0.96395475 0.6586164  0.6747164
 0.9903346  0.7748831  0.7374502  0.4587058  0.43695304 0.9827419
 0.82170707 0.25135395 0.9580979  0.9862872  0.49005967 0.6529061
 0.82863    0.8018886  0.81266123 0.93588287 0.49787262 0.78546125
 0.9597415  0.42318228 0.31291202 0.2658294  0.6479913  0.94912845
 0.4742788  0.75979775 0.48471788 0.441836   0.9092515  0.94357085
 0.5765634  0.8062707  0.6396273  0.9241755  0.58979356 0.69089156
 0.59008497 0.793422   0.71004087 0.58620185 0.98395175 0.6314479
 0.9421464  0.44288418 0.9614776  0.70739955 0.521958   0.6656614
 0.77204394 0.39640325 0.3153098  0.7565069  0.54067516 0.49120516
 0.86622316 0.74224186 0.755583   0.77655226 0.58628523 0.85986567
 0.4727107  0.98942995 0.9923171 ]
Unseen Test Accuracy (CNN model): 40.40%

Nonlinear Regression Problem¶

Similar comparison problem as above, except, for a more complex regression problem. Purpose is to detect whether an input is in the extrapolation region, or not.

An extrapolation region occurs when the model encounters data that significantly deviates from the data it was trained on. This is important, because for such problems, this is allows us to recognize when a model might be making unreliable predictions due to a lack of information or prior exposure.

Data exploration¶

California Housing Prices

In [6]:
housing_dataframe = pd.read_csv(regression_data_directory)
housing_dataframe.head()
Out[6]:
longitude latitude housing_median_age total_rooms total_bedrooms population households median_income median_house_value ocean_proximity
0 -122.23 37.88 41.0 880.0 129.0 322.0 126.0 8.3252 452600.0 NEAR BAY
1 -122.22 37.86 21.0 7099.0 1106.0 2401.0 1138.0 8.3014 358500.0 NEAR BAY
2 -122.24 37.85 52.0 1467.0 190.0 496.0 177.0 7.2574 352100.0 NEAR BAY
3 -122.25 37.85 52.0 1274.0 235.0 558.0 219.0 5.6431 341300.0 NEAR BAY
4 -122.25 37.85 52.0 1627.0 280.0 565.0 259.0 3.8462 342200.0 NEAR BAY
In [7]:
housing_dataframe.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 10 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   longitude           20640 non-null  float64
 1   latitude            20640 non-null  float64
 2   housing_median_age  20640 non-null  float64
 3   total_rooms         20640 non-null  float64
 4   total_bedrooms      20433 non-null  float64
 5   population          20640 non-null  float64
 6   households          20640 non-null  float64
 7   median_income       20640 non-null  float64
 8   median_house_value  20640 non-null  float64
 9   ocean_proximity     20640 non-null  object 
dtypes: float64(9), object(1)
memory usage: 1.6+ MB
In [8]:
label_encoder = LabelEncoder()
housing_dataframe['ocean_proximity'] = label_encoder.fit_transform(housing_dataframe['ocean_proximity'])
housing_dataframe.head()
Out[8]:
longitude latitude housing_median_age total_rooms total_bedrooms population households median_income median_house_value ocean_proximity
0 -122.23 37.88 41.0 880.0 129.0 322.0 126.0 8.3252 452600.0 3
1 -122.22 37.86 21.0 7099.0 1106.0 2401.0 1138.0 8.3014 358500.0 3
2 -122.24 37.85 52.0 1467.0 190.0 496.0 177.0 7.2574 352100.0 3
3 -122.25 37.85 52.0 1274.0 235.0 558.0 219.0 5.6431 341300.0 3
4 -122.25 37.85 52.0 1627.0 280.0 565.0 259.0 3.8462 342200.0 3

Correlation between median_house_value and rest of the features¶

It makes sense to see which features impact the target the most. Luckily, there are no samples that have null values on any column, so we can have a clean comparison. We will find that features such as median_income, housing_median_age, or total_rooms will have more impact on our target than the rest, and their distribution will matter

In [9]:
plt.figure(figsize=(10, 8))
sns.heatmap(housing_dataframe.corr(), annot=True, cmap='coolwarm', fmt='.2f')
plt.show()
No description has been provided for this image

Check the distribution on each feature¶

In [10]:
numeric_columns = housing_dataframe.select_dtypes(include=['float64', 'int64']).columns

plot_number_columns = 3  
plot_number_rows = (len(numeric_columns) + plot_number_columns - 1) // plot_number_columns  
plt.figure(figsize=(15, plot_number_rows * 4))  

for i, column in enumerate(numeric_columns, 1):
    plt.subplot(plot_number_rows, plot_number_columns, i)  
    plt.hist(housing_dataframe[column], bins=30, alpha=0.7, color='blue', edgecolor='black')
    plt.title(column) 
    plt.xlabel(column)
    plt.ylabel('Frequency')

plt.tight_layout()
plt.show()
No description has been provided for this image

Split data into features and target values¶

Our target value is median_house_value column of the dataframe.

In [11]:
features = housing_dataframe.drop(columns=['median_house_value'])  
targets = housing_dataframe['median_house_value'].values 
features.shape, targets.shape
Out[11]:
((20640, 9), (20640,))
In [12]:
print("Mean Vector of features: ", np.mean(features))
print("Mean of target prices: ", np.mean(targets))
print("Standard Deviation vector of features: ", np.std(features))
print("Standard Deviation of target prices: ", np.std(targets))
Mean Vector of features:  560.957750990283
Mean of target prices:  206855.81690891474
Standard Deviation vector of features:  longitude                2.003483
latitude                 2.135901
housing_median_age      12.585253
total_rooms           2181.562402
total_bedrooms         421.374759
population            1132.434688
households             382.320491
median_income            1.899776
ocean_proximity          1.420628
dtype: float64
Standard Deviation of target prices:  115392.82040412253
D:\Programs\Anaconda\envs\testing_env\lib\site-packages\numpy\core\fromnumeric.py:3643: FutureWarning: The behavior of DataFrame.std with axis=None is deprecated, in a future version this will reduce over both axes and return a scalar. To retain the old behavior, pass axis=0 (or do not pass axis)
  return std(axis=axis, dtype=dtype, out=out, ddof=ddof, **kwargs)

Normalize Features (opt out for now to test changes on predictions)¶

Prediction values were normalized as well and close to 0, so I'm testing without this.

In [13]:
feature_scaler = StandardScaler()
target_scaler = StandardScaler()

features_normalized = feature_scaler.fit_transform(features)
targets_normalized = target_scaler.fit_transform(targets.reshape(-1, 1))

print("Features shape:", features_normalized.shape)
print("Targets shape:", targets_normalized.shape)
Features shape: (20640, 9)
Targets shape: (20640, 1)

Use just a fraction of data (for performance and time issues)¶

In [14]:
subset_fraction = 0.1 
subset_features, _, subset_targets, _ = train_test_split(
    features_normalized, targets_normalized, 
    test_size=(1 - subset_fraction), 
    random_state=42
)
subset_features.shape, subset_targets.shape
Out[14]:
((2064, 9), (2064, 1))

Split data into training and testing¶

In [15]:
training_features, testing_features, training_targets, testing_targets = train_test_split(subset_features, subset_targets, test_size=0.2, random_state=42)
training_features.shape, testing_features.shape, training_targets.shape, testing_targets.shape
Out[15]:
((1651, 9), (413, 9), (1651, 1), (413, 1))
In [16]:
print("Mean of training features: ", np.mean(training_features))
print("Mean of testing features: ", np.mean(testing_features))
print("Mean of training targets: ", np.mean(training_targets))
print("Mean of testing targets: ", np.mean(testing_targets))
print("Standard Deviation of training features: ", np.std(training_features))
print("Standard Deviation of testing features: ", np.std(testing_features))
print("Standard Deviation of training targets: ", np.std(training_targets))
print("Standard Deviation of testing targets: ", np.std(testing_targets))
Mean of training features:  -0.0018504444401983612
Mean of testing features:  0.01442004947961341
Mean of training targets:  -0.018936552193470953
Mean of testing targets:  0.06669435665251733
Standard Deviation of training features:  1.035734405096662
Standard Deviation of testing features:  0.9603824426214513
Standard Deviation of training targets:  0.9974454291523375
Standard Deviation of testing targets:  1.0503132232872654

Bayesian Neural Network for Regression - Architecture¶

In [17]:
def regression_bayesian_network(features_params, targets_params, n_hidden=32, prior_std=1.0, likelihood_std=1.0):
    with pm.Model() as regression_bnn_model:
        features_data = pm.Data("features_params", features_params)
        targets_data = pm.Data("targets_params", targets_params)

        weights_1 = pm.Normal("weights_1", mu=0, sigma=prior_std, shape=(features_params.shape[1], n_hidden))
        bias_1 = pm.Normal("bias_1", mu=0, sigma=prior_std, shape=(n_hidden,))

        hidden_1 = pt.tanh(pt.dot(features_data, weights_1) + bias_1)

        weights_2 = pm.Normal("weights_2", mu=0, sigma=prior_std, shape=(n_hidden, n_hidden))
        bias_2 = pm.Normal("bias_2", mu=0, sigma=prior_std, shape=(n_hidden,))
        
        hidden_2 = pt.tanh(pt.dot(hidden_1, weights_2) + bias_2)

        weights_out = pm.Normal("weights_out", mu=0, sigma=prior_std, shape=(n_hidden,))
        bias_out = pm.Normal("bias_out", mu=0, sigma=prior_std)

        raw_output = pt.dot(hidden_2, weights_out) + bias_out
        output = pm.Deterministic("output", pt.expand_dims(raw_output, axis=1))

        sigma = pm.HalfNormal("sigma", sigma=likelihood_std)
        likelihood = pm.Normal("likelihood", mu=output, sigma=sigma, observed=targets_data)
        likelihood_output = pm.Deterministic("likelihood_output", likelihood)
        
        trace = pm.sample(
            1000,
            return_inferencedata=True,
            tune=4000,
            progressbar=True,
            idata_kwargs={"log_likelihood": True},
            target_accept=0.95
        )

    return regression_bnn_model, trace
In [19]:
regression_bayesian_network_model, regression_bayesian_network_trace = regression_bayesian_network(
    training_features,
    training_targets,
    n_hidden=32,
    prior_std=1.0,
    likelihood_std=np.std(training_targets)
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [weights_1, bias_1, weights_2, bias_2, weights_out, bias_out, sigma]
Output()

Sampling 4 chains for 4_000 tune and 1_000 draw iterations (16_000 + 4_000 draws total) took 24355 seconds.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
In [64]:
fig = az.plot_trace(regression_bayesian_network_trace, figsize=(10, 20))
plt.subplots_adjust(hspace=0.5) 
plt.show()
D:\Programs\Anaconda\envs\testing_env\lib\site-packages\arviz\stats\density_utils.py:488: UserWarning: Your data appears to have a single value or no finite values
  warnings.warn("Your data appears to have a single value or no finite values")
No description has been provided for this image

Evaluating it on the training set¶

In [177]:
with regression_bayesian_network_model:
    posterior_pred_train = pm.sample_posterior_predictive(regression_bayesian_network_trace, var_names=["likelihood_output", "sigma"])

posterior_pred_train.posterior_predictive
Sampling: [likelihood, sigma]
Output()

Out[177]:
<xarray.Dataset> Size: 13MB
Dimensions:                  (chain: 4, draw: 1000,
                              likelihood_output_dim_2: 413,
                              likelihood_output_dim_3: 1)
Coordinates:
  * chain                    (chain) int32 16B 0 1 2 3
  * draw                     (draw) int32 4kB 0 1 2 3 4 ... 995 996 997 998 999
  * likelihood_output_dim_2  (likelihood_output_dim_2) int32 2kB 0 1 ... 411 412
  * likelihood_output_dim_3  (likelihood_output_dim_3) int32 4B 0
Data variables:
    likelihood_output        (chain, draw, likelihood_output_dim_2, likelihood_output_dim_3) float64 13MB ...
    sigma                    (chain, draw) float64 32kB 1.286 0.2617 ... 0.6951
Attributes:
    created_at:                 2025-01-24T17:27:55.165502+00:00
    arviz_version:              0.20.0
    inference_library:          pymc
    inference_library_version:  5.20.0
xarray.Dataset
    • chain: 4
    • draw: 1000
    • likelihood_output_dim_2: 413
    • likelihood_output_dim_3: 1
    • chain
      (chain)
      int32
      0 1 2 3
      array([0, 1, 2, 3])
    • draw
      (draw)
      int32
      0 1 2 3 4 5 ... 995 996 997 998 999
      array([  0,   1,   2, ..., 997, 998, 999])
    • likelihood_output_dim_2
      (likelihood_output_dim_2)
      int32
      0 1 2 3 4 5 ... 408 409 410 411 412
      array([  0,   1,   2, ..., 410, 411, 412])
    • likelihood_output_dim_3
      (likelihood_output_dim_3)
      int32
      0
      array([0])
    • likelihood_output
      (chain, draw, likelihood_output_dim_2, likelihood_output_dim_3)
      float64
      -1.412 2.28 ... -0.8485 -0.7315
      array([[[[-1.41241064],
               [ 2.28035149],
               [ 1.54823461],
               ...,
               [-2.30990246],
               [-2.56092211],
               [ 0.63696651]],
      
              [[ 1.15946122],
               [ 1.59058408],
               [-1.36318321],
               ...,
               [-1.57407805],
               [-1.50829381],
               [-0.70420311]],
      
              [[-0.60385644],
               [ 0.73045974],
               [ 1.01282179],
               ...,
      ...
               ...,
               [-1.84945787],
               [-0.62645607],
               [-0.12300385]],
      
              [[ 0.46312584],
               [ 3.04067613],
               [-0.26682857],
               ...,
               [-1.88175319],
               [-1.25099215],
               [-0.33150492]],
      
              [[ 0.9152581 ],
               [ 1.79810274],
               [-1.98671839],
               ...,
               [-1.20646653],
               [-0.84854195],
               [-0.73146077]]]])
    • sigma
      (chain, draw)
      float64
      1.286 0.2617 ... 0.4114 0.6951
      array([[1.28593354, 0.26170184, 1.30168443, ..., 0.29374634, 0.57141946,
              0.91031735],
             [0.90169945, 0.02535657, 1.80845014, ..., 0.48799166, 0.7803237 ,
              0.15335775],
             [0.6724762 , 0.66774634, 1.37737704, ..., 0.29979882, 0.21049014,
              0.50576691],
             [0.13415786, 2.25767791, 1.00507777, ..., 0.6532311 , 0.41137022,
              0.69510895]])
    • chain
      PandasIndex
      PandasIndex(Index([0, 1, 2, 3], dtype='int32', name='chain'))
    • draw
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
            dtype='int32', name='draw', length=1000))
    • likelihood_output_dim_2
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             403, 404, 405, 406, 407, 408, 409, 410, 411, 412],
            dtype='int32', name='likelihood_output_dim_2', length=413))
    • likelihood_output_dim_3
      PandasIndex
      PandasIndex(Index([0], dtype='int32', name='likelihood_output_dim_3'))
  • created_at :
    2025-01-24T17:27:55.165502+00:00
    arviz_version :
    0.20.0
    inference_library :
    pymc
    inference_library_version :
    5.20.0
In [176]:
regression_bayesian_network_trace
Out[176]:
arviz.InferenceData
    • <xarray.Dataset> Size: 151MB
      Dimensions:                  (chain: 4, draw: 1000, weights_1_dim_0: 9,
                                    weights_1_dim_1: 32, bias_1_dim_0: 32,
                                    weights_2_dim_0: 32, weights_2_dim_1: 32,
                                    bias_2_dim_0: 32, weights_out_dim_0: 32,
                                    output_dim_0: 1651, output_dim_1: 1,
                                    likelihood_output_dim_0: 1651,
                                    likelihood_output_dim_1: 1)
      Coordinates: (12/13)
        * chain                    (chain) int32 16B 0 1 2 3
        * draw                     (draw) int32 4kB 0 1 2 3 4 ... 995 996 997 998 999
        * weights_1_dim_0          (weights_1_dim_0) int32 36B 0 1 2 3 4 5 6 7 8
        * weights_1_dim_1          (weights_1_dim_1) int32 128B 0 1 2 3 ... 29 30 31
        * bias_1_dim_0             (bias_1_dim_0) int32 128B 0 1 2 3 4 ... 28 29 30 31
        * weights_2_dim_0          (weights_2_dim_0) int32 128B 0 1 2 3 ... 29 30 31
          ...                       ...
        * bias_2_dim_0             (bias_2_dim_0) int32 128B 0 1 2 3 4 ... 28 29 30 31
        * weights_out_dim_0        (weights_out_dim_0) int32 128B 0 1 2 3 ... 29 30 31
        * output_dim_0             (output_dim_0) int32 7kB 0 1 2 3 ... 1648 1649 1650
        * output_dim_1             (output_dim_1) int32 4B 0
        * likelihood_output_dim_0  (likelihood_output_dim_0) int32 7kB 0 1 ... 1650
        * likelihood_output_dim_1  (likelihood_output_dim_1) int32 4B 0
      Data variables:
          weights_1                (chain, draw, weights_1_dim_0, weights_1_dim_1) float64 9MB ...
          bias_1                   (chain, draw, bias_1_dim_0) float64 1MB -1.288 ....
          weights_2                (chain, draw, weights_2_dim_0, weights_2_dim_1) float64 33MB ...
          bias_2                   (chain, draw, bias_2_dim_0) float64 1MB -0.1325 ...
          weights_out              (chain, draw, weights_out_dim_0) float64 1MB 0.1...
          bias_out                 (chain, draw) float64 32kB 1.223 0.9233 ... 0.4363
          sigma                    (chain, draw) float64 32kB 0.1625 0.1663 ... 0.1879
          output                   (chain, draw, output_dim_0, output_dim_1) float64 53MB ...
          likelihood_output        (chain, draw, likelihood_output_dim_0, likelihood_output_dim_1) float64 53MB ...
      Attributes:
          created_at:                 2025-01-24T04:07:53.041500+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0
          sampling_time:              24354.77615261078
          tuning_steps:               4000
      xarray.Dataset
        • chain: 4
        • draw: 1000
        • weights_1_dim_0: 9
        • weights_1_dim_1: 32
        • bias_1_dim_0: 32
        • weights_2_dim_0: 32
        • weights_2_dim_1: 32
        • bias_2_dim_0: 32
        • weights_out_dim_0: 32
        • output_dim_0: 1651
        • output_dim_1: 1
        • likelihood_output_dim_0: 1651
        • likelihood_output_dim_1: 1
        • chain
          (chain)
          int32
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int32
          0 1 2 3 4 5 ... 995 996 997 998 999
          array([  0,   1,   2, ..., 997, 998, 999])
        • weights_1_dim_0
          (weights_1_dim_0)
          int32
          0 1 2 3 4 5 6 7 8
          array([0, 1, 2, 3, 4, 5, 6, 7, 8])
        • weights_1_dim_1
          (weights_1_dim_1)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • bias_1_dim_0
          (bias_1_dim_0)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • weights_2_dim_0
          (weights_2_dim_0)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • weights_2_dim_1
          (weights_2_dim_1)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • bias_2_dim_0
          (bias_2_dim_0)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • weights_out_dim_0
          (weights_out_dim_0)
          int32
          0 1 2 3 4 5 6 ... 26 27 28 29 30 31
          array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
        • output_dim_0
          (output_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • output_dim_1
          (output_dim_1)
          int32
          0
          array([0])
        • likelihood_output_dim_0
          (likelihood_output_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • likelihood_output_dim_1
          (likelihood_output_dim_1)
          int32
          0
          array([0])
        • weights_1
          (chain, draw, weights_1_dim_0, weights_1_dim_1)
          float64
          -1.029 -1.305 ... -1.266 1.97
          array([[[[-1.02887546e+00, -1.30511864e+00, -2.00764796e+00, ...,
                     2.71019019e+00,  9.70363035e-01, -9.62098726e-02],
                   [-7.96194176e-01,  3.45744070e+00, -1.66174413e+00, ...,
                    -2.16980547e+00,  1.07881396e+00, -3.33437739e-01],
                   [-5.47680950e-01,  2.50187523e-01, -1.07422873e+00, ...,
                    -2.26271158e-02, -6.08386734e-01, -5.10778324e-01],
                   ...,
                   [ 5.32981506e-01,  1.98183540e-01, -1.78107755e+00, ...,
                     8.85181058e-01,  5.40680404e-01,  8.28105168e-01],
                   [ 3.02443304e+00,  4.63748125e-01,  1.80043080e+00, ...,
                    -1.72184775e-01, -7.45711769e-01, -1.27928725e+00],
                   [-2.26979964e+00,  4.85435274e-01, -8.75587330e-01, ...,
                    -3.46449491e+00, -7.81556027e-01, -4.58360596e-01]],
          
                  [[-1.20684790e+00, -1.42507531e+00, -1.88013899e+00, ...,
                     2.96398001e+00,  6.02461324e-01,  9.13088568e-02],
                   [-1.11604229e+00,  3.05997047e+00, -1.52636151e+00, ...,
                    -1.41613889e+00,  7.40897154e-01, -5.87727970e-02],
                   [-6.45426891e-01,  2.33269967e-01, -9.51348214e-01, ...,
                     2.44240082e-01, -5.87636353e-01, -4.55479992e-01],
          ...
                   [ 2.97241900e+00,  1.01496677e+00,  2.65321654e-02, ...,
                    -4.22098584e-01,  1.73006268e+00,  4.59933637e-02],
                   [-1.13081122e+00,  1.72822958e-01,  8.42791521e-02, ...,
                    -4.89529142e-01, -2.07200146e+00,  1.24919807e-01],
                   [ 6.11602228e-01, -6.55239701e-01, -5.94351809e-01, ...,
                     7.60745316e-02, -1.20986570e+00,  1.73029676e+00]],
          
                  [[-3.46087939e-01,  1.77747868e+00,  1.55618411e+00, ...,
                    -1.33019729e-01,  1.37285342e-01, -6.69640529e-01],
                   [-5.98120073e-01, -2.25222998e+00, -1.72709614e+00, ...,
                     1.33656018e-01, -2.75264491e+00, -6.10675194e-01],
                   [ 1.94072052e-01, -1.50952451e-01,  4.83781355e-01, ...,
                    -4.51099290e-02, -7.94847672e-01, -1.59077209e+00],
                   ...,
                   [ 2.20483467e+00,  1.26963849e-01, -6.35936236e-02, ...,
                     1.19818977e-01,  1.12475500e+00, -3.65613983e-01],
                   [-9.72984043e-01, -3.62889593e-01, -3.27743272e-01, ...,
                    -3.74134872e-01, -2.42702224e+00,  3.46795727e-01],
                   [ 5.67121159e-01, -6.86500014e-01, -6.47678615e-01, ...,
                     1.10044427e-01, -1.26607088e+00,  1.97022125e+00]]]])
        • bias_1
          (chain, draw, bias_1_dim_0)
          float64
          -1.288 2.047 ... -0.1663 0.8861
          array([[[-1.28768598,  2.0469208 ,  1.54706611, ...,  1.54880668,
                    1.83555331,  0.95648941],
                  [-1.14118473,  2.02478717,  1.05782446, ...,  1.54509924,
                    1.96556698,  1.42342517],
                  [-0.99989881,  1.31332576,  1.42131654, ...,  1.56316338,
                    2.44053918,  1.21801122],
                  ...,
                  [-0.52489967,  1.31119918,  1.58649579, ...,  1.28528799,
                    1.47678196,  1.89426675],
                  [-0.47792224,  1.25177826,  1.61229938, ...,  1.33654103,
                    1.59074621,  1.82598694],
                  [-0.59933958,  1.61942321,  1.57534446, ...,  1.13275719,
                    1.3310417 ,  1.76677954]],
          
                 [[ 1.6481139 ,  1.60140775,  0.62359312, ...,  0.83536304,
                   -0.33985658, -0.65162545],
                  [ 1.88511644,  1.88987838,  0.70938942, ...,  0.82017141,
                   -0.32733383, -0.36478001],
                  [ 1.61220105,  1.82733008,  0.66566292, ...,  0.87325828,
                   -0.18987125, -0.70766511],
          ...
                  [ 0.44935874, -1.84341838, -0.86455855, ..., -0.85887348,
                   -2.07152706, -1.22805049],
                  [ 0.42456583, -1.96221773, -0.90816208, ..., -0.70513478,
                   -2.39564297, -1.09999995],
                  [ 0.42711069, -1.8627964 , -0.89051905, ..., -0.82301377,
                   -2.56147915, -1.06844313]],
          
                 [[ 2.42370757,  2.71757208, -2.9480667 , ...,  0.94363625,
                   -0.26020862,  1.30318776],
                  [ 2.70139101,  2.45848816, -2.5280103 , ...,  0.92300257,
                   -0.19442786,  1.54980472],
                  [ 2.63721338,  2.60389563, -3.36936129, ...,  0.86677748,
                   -0.36485383,  2.14116363],
                  ...,
                  [ 3.64027744, -0.99512364, -4.08495379, ...,  0.7879665 ,
                   -0.20679885,  0.84193336],
                  [ 3.25738556, -1.71613754, -3.81037064, ...,  0.82716542,
                   -0.26220219,  0.75776334],
                  [ 3.3510687 , -1.12140705, -3.16248783, ...,  0.68417397,
                   -0.16625022,  0.88612746]]])
        • weights_2
          (chain, draw, weights_2_dim_0, weights_2_dim_1)
          float64
          -0.2875 0.9567 ... 0.4473 -0.1422
          array([[[[-2.87549506e-01,  9.56739782e-01,  1.24705583e+00, ...,
                    -4.04081723e-01,  1.31048494e+00, -1.74788814e+00],
                   [-8.75449885e-01,  1.96053286e+00, -2.88894407e-01, ...,
                     5.34626445e-01, -2.73795239e-01, -1.01742205e-01],
                   [ 5.03948336e-01,  5.55341328e-01, -3.95095523e-01, ...,
                    -4.18671560e-01,  6.48790040e-01, -2.92556040e+00],
                   ...,
                   [ 5.58339940e-01,  1.64912621e+00, -1.81592068e+00, ...,
                     1.62185007e+00, -6.48247940e-02,  1.16700561e+00],
                   [ 6.96538735e-02,  9.50832368e-01,  8.82806325e-02, ...,
                     6.03245719e-01,  4.88030584e-02, -1.63161283e-01],
                   [ 8.95260998e-01,  1.44357173e+00, -4.47628663e-01, ...,
                     4.66806208e-02,  7.47074309e-02,  6.39487537e-01]],
          
                  [[ 4.46558402e-01,  1.20031933e+00, -3.14454514e-01, ...,
                     4.54401410e-01,  1.79289209e+00, -9.53156842e-01],
                   [ 1.03700899e+00,  1.93347864e+00,  5.44711268e-01, ...,
                     1.27602852e+00, -1.93447960e-01,  1.69494190e-02],
                   [ 5.23559805e-01,  1.24439588e+00,  3.91905835e-01, ...,
                    -8.99869306e-01,  1.18249171e+00, -2.33271580e+00],
          ...
                   [-2.90991511e-01, -2.08044020e-01,  9.39159645e-01, ...,
                     1.24606818e+00,  3.26498401e-01,  8.59708073e-02],
                   [-2.48511087e-01,  6.72516053e-01,  2.11764982e+00, ...,
                    -1.31353938e-01,  5.85649450e-01,  4.24317060e-01],
                   [ 7.04752981e-01, -2.29167253e-01,  1.07209793e+00, ...,
                     5.12340861e-02,  8.47465607e-01, -7.01347768e-01]],
          
                  [[-1.28764975e+00, -8.35073059e-01, -6.64362193e-01, ...,
                    -1.20128282e+00, -5.81351122e-01,  1.31320594e+00],
                   [-8.97625444e-02, -7.26254239e-01, -1.06627973e+00, ...,
                     5.10688037e-01, -2.08844082e+00,  1.16397389e+00],
                   [-9.36362724e-01, -1.27099792e+00,  7.97109440e-01, ...,
                     1.14721004e+00, -7.21472259e-01, -1.98364521e+00],
                   ...,
                   [-1.45693356e+00,  9.75396862e-01,  1.21972764e+00, ...,
                    -6.28748094e-01, -2.67660731e-01, -3.85300007e-02],
                   [ 4.58346868e-01, -1.07873602e+00, -1.10528713e-01, ...,
                    -7.93441531e-01,  1.63244415e-01,  4.60291984e-01],
                   [-4.04155954e-01,  9.77305290e-01, -3.63258721e-01, ...,
                    -2.16730605e-01,  4.47329198e-01, -1.42243885e-01]]]])
        • bias_2
          (chain, draw, bias_2_dim_0)
          float64
          -0.1325 1.354 ... -0.7249 -0.5419
          array([[[-1.32536620e-01,  1.35360592e+00, -1.05632620e+00, ...,
                   -5.87826871e-01, -2.22589438e-01,  2.02415459e+00],
                  [ 1.75993762e+00,  1.53553617e+00, -4.24172299e-01, ...,
                    8.61005688e-01, -7.52615426e-01, -2.81856266e-01],
                  [-1.01951927e-01,  1.71996993e+00,  1.41385421e+00, ...,
                   -1.29975837e-01,  7.12667690e-01,  1.52320840e+00],
                  ...,
                  [-1.54277816e+00, -1.89289817e+00, -1.40026060e-01, ...,
                   -1.84965364e+00,  1.43653596e+00,  5.16007261e-01],
                  [-9.09485225e-01, -1.49681449e+00, -7.49167716e-02, ...,
                   -2.72654521e+00,  2.36116058e+00,  1.01044622e+00],
                  [-9.04142788e-01, -1.13852783e+00,  2.56967526e-02, ...,
                   -1.90745262e+00,  5.73277110e-01,  6.31559667e-01]],
          
                 [[ 4.67081128e-01, -4.48786598e-01,  1.56145897e+00, ...,
                    2.31606772e+00,  1.53925965e+00, -6.88341437e-01],
                  [ 7.38297231e-01,  4.63944447e-01,  2.66786037e+00, ...,
                    2.12943595e+00,  1.77415254e-01,  4.48594535e-01],
                  [ 2.61180327e-01,  8.46562704e-01,  9.22201251e-01, ...,
                    1.21442758e+00, -4.74371259e-01, -1.71004867e+00],
          ...
                  [ 2.10462318e-01, -1.12929240e+00,  1.36837021e+00, ...,
                   -8.51215205e-01, -1.94618349e-01, -4.36815206e-02],
                  [-6.58401714e-01,  4.26406736e-01,  2.16149909e+00, ...,
                   -1.18086644e+00, -4.30883850e-01, -5.50777955e-01],
                  [-7.12676393e-01, -3.09037541e-01,  2.11500585e+00, ...,
                    1.63076133e-01, -4.51793320e-01, -1.37470051e+00]],
          
                 [[-1.36844789e+00, -8.90089787e-01, -5.13594516e-01, ...,
                   -4.42484110e-01, -1.60599206e+00, -8.54413675e-01],
                  [ 4.47139910e-01, -4.76357343e-01, -2.32409162e-01, ...,
                    1.36396697e-01,  5.92489797e-01, -1.43396097e+00],
                  [ 2.32706475e-01, -2.06822699e+00, -1.02832916e-01, ...,
                    4.37863302e-02, -1.59325844e-02, -5.44275800e-01],
                  ...,
                  [-6.42463930e-01, -1.27739919e-01,  4.64508338e-01, ...,
                    3.73983102e-01, -9.35835557e-01, -6.32212358e-01],
                  [-1.73804535e+00,  8.01580734e-01, -2.71290138e-01, ...,
                   -1.02065001e+00, -1.11713951e+00, -5.30256506e-01],
                  [-1.33697302e-01, -1.45258422e+00,  1.80334341e-01, ...,
                    1.16225173e-01, -7.24875613e-01, -5.41883283e-01]]])
        • weights_out
          (chain, draw, weights_out_dim_0)
          float64
          0.1761 -1.15 ... -0.33 1.299
          array([[[ 1.76099799e-01, -1.15043519e+00, -1.24894932e-02, ...,
                   -2.40321021e-01, -1.66320751e-01, -3.12017859e-01],
                  [ 1.79437847e-01, -1.10238380e+00, -6.68353337e-03, ...,
                   -1.89656844e-01, -1.63990708e-01, -3.72307798e-01],
                  [ 1.02351294e-01, -1.11833643e+00, -7.22277014e-03, ...,
                   -2.37464334e-01, -2.13841207e-01, -3.17845265e-01],
                  ...,
                  [-1.30178200e-01, -1.06617262e+00, -9.24904232e-04, ...,
                   -1.92980068e-01, -2.21195225e-02, -3.34134388e-01],
                  [-7.12910362e-02, -1.01376286e+00, -1.91186550e-02, ...,
                   -1.72110252e-01,  9.27816731e-04, -3.35469768e-01],
                  [-1.17564102e-01, -1.09701495e+00,  4.28370259e-03, ...,
                   -1.86051013e-01, -2.49708207e-03, -3.07724808e-01]],
          
                 [[ 4.80216044e-02, -4.46339296e-01,  9.96859269e-02, ...,
                    2.49498899e-01,  7.51050956e-02, -6.99855418e-01],
                  [ 1.05211759e-01, -4.50251056e-01,  8.38846563e-02, ...,
                    1.71512915e-01,  8.23126453e-02, -7.31662512e-01],
                  [ 8.19343788e-02, -4.27153351e-01, -4.69156476e-02, ...,
                    2.28872343e-01,  1.24483991e-01, -9.18272766e-01],
          ...
                  [-1.07504806e-01,  9.29287625e-02,  4.30590752e-01, ...,
                   -8.92633683e-01,  2.26869751e+00, -4.15806294e-02],
                  [-5.38294302e-02,  1.17670751e-01,  4.08357465e-01, ...,
                   -8.38337910e-01,  2.81700072e+00, -1.79615143e-03],
                  [-1.30895965e-01,  1.17241586e-01,  3.84539054e-01, ...,
                   -9.37915897e-01,  3.24281061e+00, -6.01898963e-02]],
          
                 [[-1.08360639e-02,  2.85452910e-02,  1.13464781e-01, ...,
                    2.19398412e-02, -3.35601751e-01,  1.39910365e+00],
                  [-2.11412556e-02,  5.10284288e-02,  1.46522308e-01, ...,
                   -2.02346528e-02, -3.95950960e-01,  1.13855704e+00],
                  [ 6.82828685e-03, -6.13862840e-03,  1.27568321e-01, ...,
                    4.04350467e-02, -3.79264783e-01,  1.06575732e+00],
                  ...,
                  [ 1.37062208e-01, -5.46546800e-03,  2.94934558e-02, ...,
                   -1.61610947e-02, -2.76454821e-01,  1.24812827e+00],
                  [ 1.99152575e-01,  1.05615792e-02, -8.40147676e-03, ...,
                   -4.90571504e-02, -2.59236660e-01,  1.09376418e+00],
                  [ 1.87217560e-01,  1.45930897e-02, -3.24993821e-02, ...,
                    5.50317652e-03, -3.29970178e-01,  1.29905063e+00]]])
        • bias_out
          (chain, draw)
          float64
          1.223 0.9233 ... 0.414 0.4363
          array([[ 1.22313941,  0.9232861 ,  0.60849375, ...,  1.79778712,
                   1.84627039,  1.77414043],
                 [ 0.38886334, -0.05574717,  0.19353027, ...,  1.25810376,
                   1.18079618,  1.3629744 ],
                 [ 0.40607704,  0.93851913,  0.54525541, ..., -0.29523153,
                   0.48363728,  1.12017395],
                 [-0.0106939 ,  0.12169666, -0.02064785, ...,  0.46164929,
                   0.41395286,  0.43628482]])
        • sigma
          (chain, draw)
          float64
          0.1625 0.1663 ... 0.1813 0.1879
          array([[0.16254987, 0.16634382, 0.15773555, ..., 0.16393542, 0.16360508,
                  0.16140955],
                 [0.18353965, 0.18736972, 0.18959281, ..., 0.18416871, 0.18537189,
                  0.18616135],
                 [0.18154689, 0.17550771, 0.17969975, ..., 0.16705602, 0.17297411,
                  0.16913473],
                 [0.17724627, 0.17968671, 0.1722043 , ..., 0.18873095, 0.1812807 ,
                  0.18785447]])
        • output
          (chain, draw, output_dim_0, output_dim_1)
          float64
          0.2805 -1.072 ... -0.3061 0.01414
          array([[[[ 0.28053261],
                   [-1.07209961],
                   [-0.83003282],
                   ...,
                   [-1.00044665],
                   [-0.36370013],
                   [-0.06129262]],
          
                  [[ 0.10248303],
                   [-0.95855998],
                   [-0.72031751],
                   ...,
                   [-0.76221872],
                   [-0.34436386],
                   [-0.11356333]],
          
                  [[ 0.08002364],
                   [-1.06390717],
                   [-0.58837517],
                   ...,
          ...
                   ...,
                   [-0.87110459],
                   [-0.38842196],
                   [-0.01497993]],
          
                  [[ 0.21479285],
                   [-1.04516245],
                   [-0.69571056],
                   ...,
                   [-0.93294312],
                   [-0.42903636],
                   [ 0.10024878]],
          
                  [[-0.05022573],
                   [-1.10447567],
                   [-0.93179415],
                   ...,
                   [-1.04009348],
                   [-0.30614827],
                   [ 0.01413869]]]])
        • likelihood_output
          (chain, draw, likelihood_output_dim_0, likelihood_output_dim_1)
          float64
          0.08531 -1.094 ... -0.3619 -0.2483
          array([[[[ 0.08531019],
                   [-1.09413928],
                   [-0.87142178],
                   ...,
                   [-0.82809153],
                   [-0.36185801],
                   [-0.24833275]],
          
                  [[ 0.08531019],
                   [-1.09413928],
                   [-0.87142178],
                   ...,
                   [-0.82809153],
                   [-0.36185801],
                   [-0.24833275]],
          
                  [[ 0.08531019],
                   [-1.09413928],
                   [-0.87142178],
                   ...,
          ...
                   ...,
                   [-0.82809153],
                   [-0.36185801],
                   [-0.24833275]],
          
                  [[ 0.08531019],
                   [-1.09413928],
                   [-0.87142178],
                   ...,
                   [-0.82809153],
                   [-0.36185801],
                   [-0.24833275]],
          
                  [[ 0.08531019],
                   [-1.09413928],
                   [-0.87142178],
                   ...,
                   [-0.82809153],
                   [-0.36185801],
                   [-0.24833275]]]])
        • chain
          PandasIndex
          PandasIndex(Index([0, 1, 2, 3], dtype='int32', name='chain'))
        • draw
          PandasIndex
          PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
                 ...
                 990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
                dtype='int32', name='draw', length=1000))
        • weights_1_dim_0
          PandasIndex
          PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype='int32', name='weights_1_dim_0'))
        • weights_1_dim_1
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='weights_1_dim_1'))
        • bias_1_dim_0
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='bias_1_dim_0'))
        • weights_2_dim_0
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='weights_2_dim_0'))
        • weights_2_dim_1
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='weights_2_dim_1'))
        • bias_2_dim_0
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='bias_2_dim_0'))
        • weights_out_dim_0
          PandasIndex
          PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
                dtype='int32', name='weights_out_dim_0'))
        • output_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='output_dim_0', length=1651))
        • output_dim_1
          PandasIndex
          PandasIndex(Index([0], dtype='int32', name='output_dim_1'))
        • likelihood_output_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='likelihood_output_dim_0', length=1651))
        • likelihood_output_dim_1
          PandasIndex
          PandasIndex(Index([0], dtype='int32', name='likelihood_output_dim_1'))
      • created_at :
        2025-01-24T04:07:53.041500+00:00
        arviz_version :
        0.20.0
        inference_library :
        pymc
        inference_library_version :
        5.20.0
        sampling_time :
        24354.77615261078
        tuning_steps :
        4000

    • <xarray.Dataset> Size: 53MB
      Dimensions:           (chain: 4, draw: 1000, likelihood_dim_0: 1651,
                             likelihood_dim_1: 1)
      Coordinates:
        * chain             (chain) int32 16B 0 1 2 3
        * draw              (draw) int32 4kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * likelihood_dim_0  (likelihood_dim_0) int32 7kB 0 1 2 3 ... 1648 1649 1650
        * likelihood_dim_1  (likelihood_dim_1) int32 4B 0
      Data variables:
          likelihood        (chain, draw, likelihood_dim_0, likelihood_dim_1) float64 53MB ...
      Attributes:
          created_at:                 2025-01-24T04:08:05.683076+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0
      xarray.Dataset
        • chain: 4
        • draw: 1000
        • likelihood_dim_0: 1651
        • likelihood_dim_1: 1
        • chain
          (chain)
          int32
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int32
          0 1 2 3 4 5 ... 995 996 997 998 999
          array([  0,   1,   2, ..., 997, 998, 999])
        • likelihood_dim_0
          (likelihood_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • likelihood_dim_1
          (likelihood_dim_1)
          int32
          0
          array([0])
        • likelihood
          (chain, draw, likelihood_dim_0, likelihood_dim_1)
          float64
          0.1766 0.8886 ... 0.7092 -0.2229
          array([[[[ 0.17663116],
                   [ 0.88863998],
                   [ 0.86541543],
                   ...,
                   [ 0.33569103],
                   [ 0.89776769],
                   [ 0.23581913]],
          
                  [[ 0.86943095],
                   [ 0.54260291],
                   [ 0.46217804],
                   ...,
                   [ 0.79635034],
                   [ 0.86922967],
                   [ 0.54655928]],
          
                  [[ 0.92733519],
                   [ 0.90952943],
                   [-0.68210629],
                   ...,
          ...
                   ...,
                   [ 0.72252353],
                   [ 0.73858895],
                   [-0.01588666]],
          
                  [[ 0.53368204],
                   [ 0.75227383],
                   [ 0.31902109],
                   ...,
                   [ 0.6215005 ],
                   [ 0.72010666],
                   [-1.05996897]],
          
                  [[ 0.49287221],
                   [ 0.7516354 ],
                   [ 0.70150715],
                   ...,
                   [ 0.11634392],
                   [ 0.70917585],
                   [-0.2229435 ]]]])
        • chain
          PandasIndex
          PandasIndex(Index([0, 1, 2, 3], dtype='int32', name='chain'))
        • draw
          PandasIndex
          PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
                 ...
                 990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
                dtype='int32', name='draw', length=1000))
        • likelihood_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='likelihood_dim_0', length=1651))
        • likelihood_dim_1
          PandasIndex
          PandasIndex(Index([0], dtype='int32', name='likelihood_dim_1'))
      • created_at :
        2025-01-24T04:08:05.683076+00:00
        arviz_version :
        0.20.0
        inference_library :
        pymc
        inference_library_version :
        5.20.0

    • <xarray.Dataset> Size: 492kB
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int32 16B 0 1 2 3
        * draw                   (draw) int32 4kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables: (12/17)
          perf_counter_diff      (chain, draw) float64 32kB 5.222 5.099 ... 3.707
          index_in_trajectory    (chain, draw) int64 32kB 422 -746 760 ... -467 711
          step_size_bar          (chain, draw) float64 32kB 0.002993 ... 0.003171
          reached_max_treedepth  (chain, draw) bool 4kB True True True ... True True
          energy_error           (chain, draw) float64 32kB 0.1027 0.01674 ... 0.07637
          acceptance_rate        (chain, draw) float64 32kB 0.9929 0.9529 ... 0.984
          ...                     ...
          energy                 (chain, draw) float64 32kB 2.195e+03 ... 2.498e+03
          process_time_diff      (chain, draw) float64 32kB 4.922 4.859 ... 3.672
          step_size              (chain, draw) float64 32kB 0.002803 ... 0.00317
          largest_eigval         (chain, draw) float64 32kB nan nan nan ... nan nan
          max_energy_error       (chain, draw) float64 32kB -0.3414 -0.305 ... -0.266
          diverging              (chain, draw) bool 4kB False False ... False False
      Attributes:
          created_at:                 2025-01-24T04:07:53.071503+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0
          sampling_time:              24354.77615261078
          tuning_steps:               4000
      xarray.Dataset
        • chain: 4
        • draw: 1000
        • chain
          (chain)
          int32
          0 1 2 3
          array([0, 1, 2, 3])
        • draw
          (draw)
          int32
          0 1 2 3 4 5 ... 995 996 997 998 999
          array([  0,   1,   2, ..., 997, 998, 999])
        • perf_counter_diff
          (chain, draw)
          float64
          5.222 5.099 4.961 ... 4.085 3.707
          array([[5.2224143, 5.0990823, 4.9609629, ..., 4.3166212, 4.0233637,
                  4.3783982],
                 [4.9663063, 5.0196738, 4.9407594, ..., 3.5661953, 3.306683 ,
                  3.5758385],
                 [5.0383824, 5.0919306, 4.7363706, ..., 4.9146839, 4.8747666,
                  4.5860634],
                 [4.9603351, 4.742702 , 4.7337003, ..., 4.3061316, 4.0848552,
                  3.7073948]])
        • index_in_trajectory
          (chain, draw)
          int64
          422 -746 760 -348 ... -487 -467 711
          array([[ 422, -746,  760, ..., -375, -277,  368],
                 [-477, -441,  519, ..., -540,  689, -278],
                 [ 191, -537,  410, ...,  769, -696, -644],
                 [-264,  767, -852, ..., -487, -467,  711]], dtype=int64)
        • step_size_bar
          (chain, draw)
          float64
          0.002993 0.002993 ... 0.003171
          array([[0.0029934 , 0.0029934 , 0.0029934 , ..., 0.0029934 , 0.0029934 ,
                  0.0029934 ],
                 [0.00279618, 0.00279618, 0.00279618, ..., 0.00279618, 0.00279618,
                  0.00279618],
                 [0.00249535, 0.00249535, 0.00249535, ..., 0.00249535, 0.00249535,
                  0.00249535],
                 [0.00317117, 0.00317117, 0.00317117, ..., 0.00317117, 0.00317117,
                  0.00317117]])
        • reached_max_treedepth
          (chain, draw)
          bool
          True True True ... True True True
          array([[ True,  True,  True, ...,  True,  True,  True],
                 [ True,  True,  True, ...,  True,  True,  True],
                 [ True,  True,  True, ...,  True,  True,  True],
                 [ True,  True,  True, ...,  True,  True,  True]])
        • energy_error
          (chain, draw)
          float64
          0.1027 0.01674 ... -0.1592 0.07637
          array([[ 0.10270985,  0.01673711,  0.20526556, ..., -0.19813523,
                   0.07555737, -0.01397414],
                 [-0.01685923,  0.19174112, -0.5031419 , ...,  0.09607443,
                   0.00181521,  0.05428212],
                 [ 0.06199411, -0.06856122, -0.03346366, ..., -0.11597096,
                   0.06214703,  0.04984836],
                 [ 0.15174472, -0.02023078,  0.15623833, ...,  0.13440444,
                  -0.15922876,  0.07637308]])
        • acceptance_rate
          (chain, draw)
          float64
          0.9929 0.9529 ... 0.9966 0.984
          array([[0.99290756, 0.95290219, 0.8705762 , ..., 0.98165567, 0.93046535,
                  0.95660358],
                 [0.96051724, 0.85909354, 0.97730259, ..., 0.86451597, 0.98869539,
                  0.97782296],
                 [0.93559425, 0.99382987, 0.9915182 , ..., 0.99363132, 0.99316139,
                  0.96415292],
                 [0.87639649, 0.99077679, 0.79609607, ..., 0.96526755, 0.99660888,
                  0.98399989]])
        • n_steps
          (chain, draw)
          float64
          1.023e+03 1.023e+03 ... 1.023e+03
          array([[1023., 1023., 1023., ..., 1023., 1023., 1023.],
                 [1023., 1023., 1023., ..., 1023., 1023., 1023.],
                 [1023., 1023., 1023., ..., 1023., 1023., 1023.],
                 [1023., 1023., 1023., ..., 1023., 1023., 1023.]])
        • smallest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]])
        • perf_counter_start
          (chain, draw)
          float64
          6.3e+05 6.3e+05 ... 6.35e+05
          array([[629987.6799434, 629992.9026028, 629998.0019734, ...,
                  634942.2594369, 634946.5762616, 634950.5998434],
                 [630096.9712949, 630101.9379222, 630106.9578705, ...,
                  634999.45269  , 635003.0191826, 635006.3260679],
                 [629987.1318287, 629992.170522 , 629997.2629476, ...,
                  634909.6002821, 634914.51519  , 634919.3902664],
                 [630029.5899648, 630034.5505986, 630039.2935884, ...,
                  634949.4814616, 634953.7878201, 634957.8728708]])
        • lp
          (chain, draw)
          float64
          -1.527e+03 -1.548e+03 ... -1.76e+03
          array([[-1527.01698205, -1547.95934159, -1549.59806162, ...,
                  -1532.02486355, -1547.19711405, -1569.38962665],
                 [-1661.40289951, -1702.65908444, -1700.4066471 , ...,
                  -1696.32471799, -1686.86658564, -1732.88154285],
                 [-1639.4033479 , -1612.55464688, -1620.81799941, ...,
                  -1548.55792802, -1572.98867751, -1563.83311634],
                 [-1638.33461353, -1627.14085801, -1613.20134564, ...,
                  -1746.29533209, -1746.14623703, -1760.45936425]])
        • tree_depth
          (chain, draw)
          int64
          10 10 10 10 10 ... 10 10 10 10 10
          array([[10, 10, 10, ..., 10, 10, 10],
                 [10, 10, 10, ..., 10, 10, 10],
                 [10, 10, 10, ..., 10, 10, 10],
                 [10, 10, 10, ..., 10, 10, 10]], dtype=int64)
        • energy
          (chain, draw)
          float64
          2.195e+03 2.276e+03 ... 2.498e+03
          array([[2194.5523502 , 2275.70585712, 2259.67419831, ..., 2257.16372635,
                  2244.13772384, 2272.15381773],
                 [2379.29128246, 2364.71378081, 2430.03891088, ..., 2402.04306536,
                  2390.194083  , 2376.33159293],
                 [2327.40305794, 2354.66209741, 2338.28912179, ..., 2297.97501518,
                  2236.95585441, 2228.08836679],
                 [2368.51965977, 2318.37731257, 2327.06085251, ..., 2420.625324  ,
                  2444.07482804, 2497.97697731]])
        • process_time_diff
          (chain, draw)
          float64
          4.922 4.859 4.797 ... 3.953 3.672
          array([[4.921875, 4.859375, 4.796875, ..., 4.140625, 4.015625, 4.21875 ],
                 [4.734375, 4.8125  , 4.75    , ..., 3.40625 , 3.3125  , 3.421875],
                 [4.84375 , 4.84375 , 4.734375, ..., 4.796875, 4.734375, 4.53125 ],
                 [4.734375, 4.640625, 4.71875 , ..., 4.15625 , 3.953125, 3.671875]])
        • step_size
          (chain, draw)
          float64
          0.002803 0.002803 ... 0.00317
          array([[0.00280318, 0.00280318, 0.00280318, ..., 0.00280318, 0.00280318,
                  0.00280318],
                 [0.00261901, 0.00261901, 0.00261901, ..., 0.00261901, 0.00261901,
                  0.00261901],
                 [0.00237044, 0.00237044, 0.00237044, ..., 0.00237044, 0.00237044,
                  0.00237044],
                 [0.00317017, 0.00317017, 0.00317017, ..., 0.00317017, 0.00317017,
                  0.00317017]])
        • largest_eigval
          (chain, draw)
          float64
          nan nan nan nan ... nan nan nan nan
          array([[nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan],
                 [nan, nan, nan, ..., nan, nan, nan]])
        • max_energy_error
          (chain, draw)
          float64
          -0.3414 -0.305 ... -0.429 -0.266
          array([[-0.34136233, -0.30503702,  0.36813893, ..., -0.41608003,
                   0.26928811, -0.28993222],
                 [-0.33726732,  0.50274536, -0.54551177, ...,  0.39647597,
                  -0.16545548, -0.23810795],
                 [ 0.2679499 , -0.49122377, -0.48987172, ..., -0.4673793 ,
                  -0.28154161, -0.38309645],
                 [ 0.51470694, -0.30231338,  0.53689515, ..., -0.23066551,
                  -0.42900919, -0.26601292]])
        • diverging
          (chain, draw)
          bool
          False False False ... False False
          array([[False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False],
                 [False, False, False, ..., False, False, False]])
        • chain
          PandasIndex
          PandasIndex(Index([0, 1, 2, 3], dtype='int32', name='chain'))
        • draw
          PandasIndex
          PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
                 ...
                 990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
                dtype='int32', name='draw', length=1000))
      • created_at :
        2025-01-24T04:07:53.071503+00:00
        arviz_version :
        0.20.0
        inference_library :
        pymc
        inference_library_version :
        5.20.0
        sampling_time :
        24354.77615261078
        tuning_steps :
        4000

    • <xarray.Dataset> Size: 20kB
      Dimensions:           (likelihood_dim_0: 1651, likelihood_dim_1: 1)
      Coordinates:
        * likelihood_dim_0  (likelihood_dim_0) int32 7kB 0 1 2 3 ... 1648 1649 1650
        * likelihood_dim_1  (likelihood_dim_1) int32 4B 0
      Data variables:
          likelihood        (likelihood_dim_0, likelihood_dim_1) float64 13kB 0.085...
      Attributes:
          created_at:                 2025-01-24T04:07:53.079503+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0
      xarray.Dataset
        • likelihood_dim_0: 1651
        • likelihood_dim_1: 1
        • likelihood_dim_0
          (likelihood_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • likelihood_dim_1
          (likelihood_dim_1)
          int32
          0
          array([0])
        • likelihood
          (likelihood_dim_0, likelihood_dim_1)
          float64
          0.08531 -1.094 ... -0.3619 -0.2483
          array([[ 0.08531019],
                 [-1.09413928],
                 [-0.87142178],
                 ...,
                 [-0.82809153],
                 [-0.36185801],
                 [-0.24833275]])
        • likelihood_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='likelihood_dim_0', length=1651))
        • likelihood_dim_1
          PandasIndex
          PandasIndex(Index([0], dtype='int32', name='likelihood_dim_1'))
      • created_at :
        2025-01-24T04:07:53.079503+00:00
        arviz_version :
        0.20.0
        inference_library :
        pymc
        inference_library_version :
        5.20.0

    • <xarray.Dataset> Size: 145kB
      Dimensions:                (features_params_dim_0: 1651,
                                  features_params_dim_1: 9,
                                  targets_params_dim_0: 1651, targets_params_dim_1: 1)
      Coordinates:
        * features_params_dim_0  (features_params_dim_0) int32 7kB 0 1 2 ... 1649 1650
        * features_params_dim_1  (features_params_dim_1) int32 36B 0 1 2 3 4 5 6 7 8
        * targets_params_dim_0   (targets_params_dim_0) int32 7kB 0 1 2 ... 1649 1650
        * targets_params_dim_1   (targets_params_dim_1) int32 4B 0
      Data variables:
          features_params        (features_params_dim_0, features_params_dim_1) float64 119kB ...
          targets_params         (targets_params_dim_0, targets_params_dim_1) float64 13kB ...
      Attributes:
          created_at:                 2025-01-24T04:07:53.081503+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0
      xarray.Dataset
        • features_params_dim_0: 1651
        • features_params_dim_1: 9
        • targets_params_dim_0: 1651
        • targets_params_dim_1: 1
        • features_params_dim_0
          (features_params_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • features_params_dim_1
          (features_params_dim_1)
          int32
          0 1 2 3 4 5 6 7 8
          array([0, 1, 2, 3, 4, 5, 6, 7, 8])
        • targets_params_dim_0
          (targets_params_dim_0)
          int32
          0 1 2 3 4 ... 1647 1648 1649 1650
          array([   0,    1,    2, ..., 1648, 1649, 1650])
        • targets_params_dim_1
          (targets_params_dim_1)
          int32
          0
          array([0])
        • features_params
          (features_params_dim_0, features_params_dim_1)
          float64
          0.6936 -0.8483 ... -0.2872 -0.8207
          array([[ 0.69364418, -0.84828919,  1.14105882, ..., -0.21327573,
                  -0.19511304,  1.99500282],
                 [-1.33282653,  2.25578777, -0.13027044, ..., -0.5219173 ,
                  -0.14363327, -0.11673923],
                 [ 0.68366157, -0.87169852,  1.06160074, ..., -0.75209069,
                  -0.8662449 ,  1.99500282],
                 ...,
                 [-0.05006059,  0.56095239, -2.03726433, ..., -0.41729304,
                   0.29736616, -0.11673923],
                 [ 0.74355725, -0.79678867,  0.42593611, ..., -0.42513986,
                  -0.34313051, -0.82065324],
                 [ 0.5638702 , -0.66101456,  0.8232265 , ..., -0.48529881,
                  -0.28722917, -0.82065324]])
        • targets_params
          (targets_params_dim_0, targets_params_dim_1)
          float64
          0.08531 -1.094 ... -0.3619 -0.2483
          array([[ 0.08531019],
                 [-1.09413928],
                 [-0.87142178],
                 ...,
                 [-0.82809153],
                 [-0.36185801],
                 [-0.24833275]])
        • features_params_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='features_params_dim_0', length=1651))
        • features_params_dim_1
          PandasIndex
          PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype='int32', name='features_params_dim_1'))
        • targets_params_dim_0
          PandasIndex
          PandasIndex(Index([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
                 ...
                 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650],
                dtype='int32', name='targets_params_dim_0', length=1651))
        • targets_params_dim_1
          PandasIndex
          PandasIndex(Index([0], dtype='int32', name='targets_params_dim_1'))
      • created_at :
        2025-01-24T04:07:53.081503+00:00
        arviz_version :
        0.20.0
        inference_library :
        pymc
        inference_library_version :
        5.20.0

In [22]:
posterior_pred_train.posterior_predictive["likelihood_output"].shape
Out[22]:
(4, 1000, 1651, 1)
In [ ]:
 
In [51]:
predictions_train = posterior_pred_train.posterior_predictive["likelihood_output"].mean(axis=(0, 1)).values
predictions_train += 1e-9
predictions_train_reshaped = predictions_train.reshape(-1, 1)
predictions_train_original_scale = target_scaler.inverse_transform(predictions_train_reshaped).flatten()

print("Predictions in original scale:", predictions_train_original_scale[:10])
Predictions in original scale: [217342.63355987  90308.14254984 115919.91783123 160124.88126094
 184249.53096355 230531.95613715 195922.19899313 276006.84772789
 500123.87462741 421481.23234539]
In [52]:
uncertainty_train = posterior_pred_train.posterior_predictive["likelihood_output"].std(axis=(0, 1))
threshold_train = uncertainty_train.mean() + 2 * uncertainty_train.std()
extrapolation_points_train = uncertainty_train > threshold_train
extrapolation_points_train = extrapolation_points_train.values.flatten()

predictive_entropy = entropy(predictions_train_original_scale, axis=0)
threshold_entropy = predictive_entropy.mean() + 2 * predictive_entropy.std()
extrapolation_points_entropy = predictive_entropy > threshold_entropy

print("Train extrapolation regions identified:", np.sum(extrapolation_points_train))
print("Train extrapolation regions identified with entropy:", np.sum(extrapolation_points_entropy))
Train extrapolation regions identified: 18
Train extrapolation regions identified with entropy: 0
In [54]:
print("Shape of training_targets:", training_targets.shape)
print("Shape of predictions:", predictions_train_original_scale.shape)
print("Shape of extrapolation_points:", extrapolation_points_train.shape)
Shape of training_targets: (1651, 1)
Shape of predictions: (1651,)
Shape of extrapolation_points: (1651,)
In [130]:
plt.figure(figsize=(10, 6))
plt.scatter(range(len(predictions_train_original_scale)), predictions_train_original_scale, label="Predictions", alpha=0.7)
plt.scatter(range(len(training_targets)), target_scaler.inverse_transform(training_targets).flatten(), label="True Values", alpha=0.7)
plt.scatter(
    np.where(extrapolation_points_train)[0], 
    predictions_train_original_scale[extrapolation_points_train],
    color="green", label="Extrapolation Points", alpha=0.7
)
plt.legend()
plt.title("Predictions and Extrapolation Points")
plt.xlabel("Sample Index")
plt.ylabel("Median House Value")
plt.show()
No description has been provided for this image

Evaluating it on testing subset¶

In [58]:
with regression_bayesian_network_model:
    pm.set_data({"features_params": testing_features, "targets_params": testing_targets})
    posterior_pred_test = pm.sample_posterior_predictive(
        regression_bayesian_network_trace, var_names=["likelihood_output"]
    )

posterior_pred_test.posterior_predictive
Sampling: [likelihood]
Output()

Out[58]:
<xarray.Dataset> Size: 13MB
Dimensions:                  (chain: 4, draw: 1000,
                              likelihood_output_dim_2: 413,
                              likelihood_output_dim_3: 1)
Coordinates:
  * chain                    (chain) int32 16B 0 1 2 3
  * draw                     (draw) int32 4kB 0 1 2 3 4 ... 995 996 997 998 999
  * likelihood_output_dim_2  (likelihood_output_dim_2) int32 2kB 0 1 ... 411 412
  * likelihood_output_dim_3  (likelihood_output_dim_3) int32 4B 0
Data variables:
    likelihood_output        (chain, draw, likelihood_output_dim_2, likelihood_output_dim_3) float64 13MB ...
Attributes:
    created_at:                 2025-01-24T10:44:44.475911+00:00
    arviz_version:              0.20.0
    inference_library:          pymc
    inference_library_version:  5.20.0
xarray.Dataset
    • chain: 4
    • draw: 1000
    • likelihood_output_dim_2: 413
    • likelihood_output_dim_3: 1
    • chain
      (chain)
      int32
      0 1 2 3
      array([0, 1, 2, 3])
    • draw
      (draw)
      int32
      0 1 2 3 4 5 ... 995 996 997 998 999
      array([  0,   1,   2, ..., 997, 998, 999])
    • likelihood_output_dim_2
      (likelihood_output_dim_2)
      int32
      0 1 2 3 4 5 ... 408 409 410 411 412
      array([  0,   1,   2, ..., 410, 411, 412])
    • likelihood_output_dim_3
      (likelihood_output_dim_3)
      int32
      0
      array([0])
    • likelihood_output
      (chain, draw, likelihood_output_dim_2, likelihood_output_dim_3)
      float64
      1.406 1.846 ... -1.006 -0.3212
      array([[[[ 1.40640302e+00],
               [ 1.84550724e+00],
               [-1.09514386e+00],
               ...,
               [-1.20529123e+00],
               [-1.36791293e+00],
               [-4.51869220e-01]],
      
              [[ 1.34327674e+00],
               [ 8.99120520e-01],
               [-1.55827121e+00],
               ...,
               [-1.45587563e+00],
               [-1.37115875e+00],
               [-7.63397055e-01]],
      
              [[ 5.40486443e-01],
               [ 1.58735562e+00],
               [-1.19116003e+00],
               ...,
      ...
               ...,
               [-1.08081387e+00],
               [-1.17887985e+00],
               [-8.77780463e-01]],
      
              [[ 1.51214229e-01],
               [ 3.01196236e+00],
               [-1.43276798e+00],
               ...,
               [-1.84231592e+00],
               [-1.14691116e+00],
               [-3.89327888e-01]],
      
              [[-2.47018646e-01],
               [ 2.58320998e+00],
               [-1.50838257e+00],
               ...,
               [-1.13980906e+00],
               [-1.00608663e+00],
               [-3.21213934e-01]]]])
    • chain
      PandasIndex
      PandasIndex(Index([0, 1, 2, 3], dtype='int32', name='chain'))
    • draw
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
            dtype='int32', name='draw', length=1000))
    • likelihood_output_dim_2
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             403, 404, 405, 406, 407, 408, 409, 410, 411, 412],
            dtype='int32', name='likelihood_output_dim_2', length=413))
    • likelihood_output_dim_3
      PandasIndex
      PandasIndex(Index([0], dtype='int32', name='likelihood_output_dim_3'))
  • created_at :
    2025-01-24T10:44:44.475911+00:00
    arviz_version :
    0.20.0
    inference_library :
    pymc
    inference_library_version :
    5.20.0
In [59]:
posterior_pred_test.posterior_predictive["likelihood_output"].shape
Out[59]:
(4, 1000, 413, 1)
In [60]:
predictions_test = posterior_pred_test.posterior_predictive["likelihood_output"].mean(axis=(0, 1)).values
predictions_test += 1e-9  
predictions_test_reshaped = predictions_test.reshape(-1, 1)
predictions_test_original_scale = target_scaler.inverse_transform(predictions_test_reshaped).flatten()

print("Predictions for testing in original scale:", predictions_test_original_scale[:10])
Predictions for testing in original scale: [234730.68734274 409353.95600658  57025.93738655  99919.39064562
 239081.55686621 117202.12889563 469941.4366122  135659.8159613
 210074.71760409 475256.15052144]
In [61]:
uncertainty_test = posterior_pred_test.posterior_predictive["likelihood_output"].std(axis=(0, 1))
threshold_test = uncertainty_test.mean() + 2 * uncertainty_test.std()
extrapolation_points_test = uncertainty_test > threshold_test
extrapolation_points_test = extrapolation_points_test.values.flatten()

predictive_entropy_test = entropy(predictions_test_original_scale, axis=0)
threshold_entropy_test = predictive_entropy_test.mean() + 2 * predictive_entropy_test.std()
extrapolation_points_entropy_test = predictive_entropy_test > threshold_entropy_test

print("Test extrapolation regions identified:", np.sum(extrapolation_points_test))
print("Test extrapolation regions identified with entropy:", np.sum(extrapolation_points_entropy_test))
Test extrapolation regions identified: 18
Test extrapolation regions identified with entropy: 0
In [62]:
print("Shape of training_targets:", testing_targets.shape)
print("Shape of predictions:", predictions_test_original_scale.shape)
print("Shape of extrapolation_points:", extrapolation_points_test.shape)
Shape of training_targets: (413, 1)
Shape of predictions: (413,)
Shape of extrapolation_points: (413,)
In [63]:
plt.figure(figsize=(10, 6))
plt.scatter(
    range(len(predictions_test_original_scale)),
    predictions_test_original_scale,
    label="Predictions", alpha=0.7
)
plt.scatter(
    range(len(testing_targets)),
    target_scaler.inverse_transform(testing_targets).flatten(),
    label="True Values", alpha=0.7
)
plt.scatter(
    np.where(extrapolation_points_test)[0],
    predictions_test_original_scale[extrapolation_points_test],
    color="green", label="Extrapolation Points", alpha=0.7
)
plt.legend()
plt.title("Testing Predictions and Extrapolation Points")
plt.xlabel("Sample Index")
plt.ylabel("Median House Value")
plt.show()
No description has been provided for this image

Conclusions¶

How to design the BNN using pymc and compare unseen classes¶

Before using entropy, I attempted a simple accuracy score for predictions sampling (would give a vector with values sample from the multivariable distribution obtained). That being said, I still fail to see how to make that comparison.

Normalization of data helps BNN¶

This is what I've observed, at least for the regression problem. The remaining step is to revert the transformation at the end. But these additional steps are worth it for the accuracy of the predictions

Subset the data, or decrease features¶

Sampling can take very long, and depending on the number of chains also, convergence might take very long to be attained. A compromise can be obtained in truncating data (less images, or less rows in regression), or decrease features(decrease image resolution, or get rid of columns that have very little influence over the target in regression).

Bayesian Neural Networks vs Classical Neural Networks¶

The latter are faster to train and do inference on (the sampling time for the BNN is high). But how to measure uncertainty on classical networks? Well, comparison between accuracies isn't a good measure. However, BNNs explicitly model uncertainty by treating the weights of the network as random variables with prior probability distributions, rather than as fixed values and there is a measure based on entropy for comparison, for example. At least that should be the theory.